mmgp 3.0.0__py3-none-any.whl → 3.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mmgp might be problematic.

mmgp/offload.py CHANGED
@@ -1,1472 +1,1474 @@
1
- # ------------------ Memory Management 3.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
- #
3
- # This module contains multiple optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a GPU limited to 24 GB of VRAM.
4
- # This is a replacement for the accelerate library, which should in theory manage offloading but doesn't work properly with models that are loaded / unloaded several
5
- # times in a pipe (e.g. the VAE).
6
- #
7
- # Requirements (for Linux; on Windows add 16 GB of RAM):
8
- # - VRAM: minimum 12 GB, recommended 24 GB (RTX 3090/ RTX 4090)
9
- # - RAM: minimum 24 GB, recommended 48 - 64 GB
10
- #
11
- # It is almost plug and play and just needs to be invoked from the main app right after the model pipeline has been created.
12
- # Make sure that the pipeline explicitly loads the models on the CPU device
13
- # for instance: pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cpu")
14
- # For a quick setup, you may want to choose between 5 profiles depending on your hardware, for instance:
15
- # from mmgp import offload, profile_type
16
- # offload.profile(pipe, profile_type.HighRAM_LowVRAM_Fast)
17
- # Alternatively you may want to set your own parameters, for instance:
18
- # from mmgp import offload
19
- # offload.all(pipe, pinnedMemory=True, extraModelsToQuantize = ["text_encoder_2"] )
20
- # The 'transformer' model, which usually contains the video or image generator, is quantized on the fly to 8 bits by default so that it can fit into 24 GB of VRAM.
21
- # You can prevent the transformer quantization by adding the parameter quantizeTransformer = False
22
- # If you want to save disk space and reduce the loading time, you may want to load a prequantized model directly. In that case you need to set the option quantizeTransformer to False to turn off on-the-fly quantization.
23
- # You can specify a list of additional model string ids to quantize (for instance the text_encoder) using the optional argument extraModelsToQuantize. This may be useful if you have less than 48 GB of RAM.
24
- # Note that there is little advantage on the GPU / VRAM side to quantize text encoders as their inputs are usually quite light.
25
- # Conversely, if you have more than 48 GB of RAM you may want to enable RAM pinning with the option pinnedMemory = True. In return you will get super fast loading / unloading of models
26
- # (this can save significant time if the same pipeline is run multiple times in a row)
27
- #
28
- # Sometimes there isn't an explicit pipe object as each submodel is loaded separately in the main app. If this is the case, you need to create a dictionary that manually maps all the models.
29
- #
30
- # For instance :
31
- # for flux derived models: pipe = { "text_encoder": clip, "text_encoder_2": t5, "transformer": model, "vae":ae }
32
- # for mochi: pipe = { "text_encoder": self.text_encoder, "transformer": self.dit, "vae":self.decoder }
33
- #
34
- # Please note that there should always be one model whose id is 'transformer'. It corresponds to the main image / video model which usually needs to be quantized (this is done on the fly by default when loading the model)
35
- #
36
- # Be careful: lots of models use T5 XXL as a text encoder. However, quite often their corresponding pipeline configurations point at the official Google T5 XXL repository
37
- # where there is a huge 40 GB model to download and load. It is cumbersome as it is a 32-bit model and contains the decoder part of T5, which is not used.
38
- # I suggest you instead use one of the 16-bit encoder-only versions available, for instance:
39
- # text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=torch.float16)
40
- #
41
- # Sometimes just providing the pipe won't be sufficient and you will need to change the content of the core model:
42
- # - For instance you may need to disable CPU offload logic that already exists (such as manual calls to move tensors between cuda and the cpu)
43
- # - mmgp tries to fake the device as being "cuda", but sometimes some code won't be fooled and will create tensors on the cpu device, which may cause issues.
44
- #
45
- # You are free to use my module for non-commercial purposes as long as you give me proper credit. You may contact me on twitter @deepbeepmeep
46
- #
47
- # Thanks to
48
- # ---------
49
- # Huggingface / accelerate for the hooking examples
50
- # Huggingface / quanto for their very useful quantizer
51
- # gau-nernst for his RAM pinning samples
52
-
53
-
54
- #
55
-
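- # A minimal end-to-end sketch of the setup described above (the pipeline and the profile chosen
- # below are illustrative assumptions, adapt them to your app):
- #   import torch
- #   from diffusers import FluxPipeline
- #   from mmgp import offload, profile_type
- #   pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cpu")
- #   offload.profile(pipe, profile_type.HighRAM_LowVRAM_Fast)
- #   # or, with explicit options instead of a profile:
- #   # offload.all(pipe, pinnedMemory=True, extraModelsToQuantize=["text_encoder_2"])
- #   image = pipe("a cat reading a diff", num_inference_steps=4).images[0]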
56
- import torch
57
- import gc
58
- import time
59
- import functools
60
- import sys
61
- import os
62
- import json
63
- import psutil
64
- from mmgp import safetensors2
65
- from mmgp import profile_type
66
-
67
- from optimum.quanto import freeze, qfloat8, qint8, quantize, QModuleMixin, QTensor, WeightQBytesTensor, quantize_module
68
-
69
-
70
-
71
-
72
- mmm = safetensors2.mmm
73
-
74
- ONE_MB = 1048576
75
- sizeofbfloat16 = torch.bfloat16.itemsize
76
- sizeofint8 = torch.int8.itemsize
77
- total_pinned_bytes = 0
78
- physical_memory= psutil.virtual_memory().total
79
-
80
- HEADER = '\033[95m'
81
- ENDC = '\033[0m'
82
- BOLD ='\033[1m'
83
- UNBOLD ='\033[0m'
84
-
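- # Models listed here are allowed to stay loaded in VRAM together with the model named by the key,
- # so activating one of them does not force the others to be offloaded (see can_model_be_cotenant below).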
85
- cotenants_map = {
86
- "text_encoder": ["vae", "text_encoder_2"],
87
- "text_encoder_2": ["vae", "text_encoder"],
88
- }
89
-
90
- class clock:
91
- def __init__(self):
92
- self.start_time = 0
93
- self.end_time = 0
94
-
95
- @classmethod
96
- def start(cls):
97
- self = cls()
98
- self.start_time =time.time()
99
- return self
100
-
101
- def stop(self):
102
- self.stop_time =time.time()
103
-
104
- def time_gap(self):
105
- return self.stop_time - self.start_time
106
-
107
- def format_time_gap(self):
108
- return f"{self.stop_time - self.start_time:.2f}s"
109
-
110
-
111
-
112
- # useful functions to move a group of tensors (to design custom offload patches)
113
- def move_tensors(obj, device):
114
- if torch.is_tensor(obj):
115
- return obj.to(device)
116
- elif isinstance(obj, dict):
117
- _dict = {}
118
- for k, v in obj.items():
119
- _dict[k] = move_tensors(v, device)
120
- return _dict
121
- elif isinstance(obj, list):
122
- _list = []
123
- for v in obj:
124
- _list.append(move_tensors(v, device))
125
- return _list
126
- else:
127
- raise TypeError("Tensor or list / dict of tensors expected")
128
-
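- # A small hedged usage sketch for move_tensors (variable names are illustrative): it walks nested
- # lists / dicts and moves every tensor it finds, raising TypeError on anything else.
- #   batch = {"latents": torch.randn(1, 16), "masks": [torch.ones(1, 16)]}
- #   batch_gpu = move_tensors(batch, "cuda")
- #   batch_cpu = move_tensors(batch_gpu, "cpu")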
129
-
130
- def _get_max_reservable_memory(perc_reserved_mem_max):
131
- if perc_reserved_mem_max<=0:
132
- perc_reserved_mem_max = 0.40 if os.name == 'nt' else 0.5
133
- return perc_reserved_mem_max * physical_memory
134
-
135
- def _detect_main_towers(model, verboseLevel=1):
136
- cur_blocks_prefix = None
137
- towers_modules= []
138
- towers_names= []
139
-
140
- for submodule_name, submodule in model.named_modules():
141
- if submodule_name=='':
142
- continue
143
-
144
- if isinstance(submodule, torch.nn.ModuleList):
145
- newList =False
146
- if cur_blocks_prefix == None:
147
- cur_blocks_prefix = submodule_name + "."
148
- newList = True
149
- else:
150
- if not submodule_name.startswith(cur_blocks_prefix):
151
- cur_blocks_prefix = submodule_name + "."
152
- newList = True
153
-
154
- if newList and len(submodule)>=5:
155
- towers_names.append(submodule_name)
156
- towers_modules.append(submodule)
157
-
158
- else:
159
- if cur_blocks_prefix is not None:
160
- if not submodule_name.startswith(cur_blocks_prefix):
161
- cur_blocks_prefix = None
162
-
163
- return towers_names, towers_modules
164
-
165
-
166
-
167
- def _get_model(model_path):
168
- if os.path.isfile(model_path):
169
- return model_path
170
-
171
- from pathlib import Path
172
- _path = Path(model_path).parts
173
- _filename = _path[-1]
174
- _path = _path[:-1]
175
- if len(_path)==1:
176
- raise("file not found")
177
- else:
178
- from huggingface_hub import hf_hub_download #snapshot_download,
179
- repoId= os.path.join(*_path[0:2] ).replace("\\", "/")
180
-
181
- if len(_path) > 2:
182
- _subfolder = os.path.join(*_path[2:] )
183
- model_path = hf_hub_download(repo_id=repoId, filename=_filename, subfolder=_subfolder)
184
- else:
185
- model_path = hf_hub_download(repo_id=repoId, filename=_filename)
186
-
187
- return model_path
188
-
189
-
190
-
191
- def _remove_model_wrapper(model):
192
- if not model._modules is None:
193
- if len(model._modules)!=1:
194
- return model
195
- sub_module = model._modules[next(iter(model._modules))]
196
- if hasattr(sub_module,"config") or hasattr(sub_module,"base_model"):
197
- return sub_module
198
- return model
199
-
200
-
201
-
202
- def _move_to_pinned_tensor(source_tensor, big_tensor, offset, length):
203
- dtype= source_tensor.dtype
204
- shape = source_tensor.shape
205
- if len(shape) == 0:
206
- return source_tensor
207
- else:
208
- t = source_tensor.view(torch.uint8)
209
- t = torch.reshape(t, (length,))
210
- # magic swap !
211
- big_tensor[offset: offset + length] = t
212
- t = big_tensor[offset: offset + length]
213
- t = t.view(dtype)
214
- t = torch.reshape(t, shape)
215
- assert t.is_pinned()
216
- return t
217
-
218
- def _safetensors_load_file(file_path):
219
- from collections import OrderedDict
220
- sd = OrderedDict()
221
-
222
- with safetensors2.safe_open(file_path, framework="pt", device="cpu") as f:
223
- for k in f.keys():
224
- sd[k] = f.get_tensor(k)
225
- metadata = f.metadata()
226
-
227
- return sd, metadata
228
-
229
- def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
230
- if verboseLevel>=1 :
231
- if partialPinning:
232
- print(f"Partial pinning to RAM of data of '{model_id}'")
233
- else:
234
- print(f"Pinning data to RAM of '{model_id}'")
235
-
236
- max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
237
- if partialPinning:
238
- towers_names, _ = _detect_main_towers(model)
239
- towers_names = [n +"." for n in towers_names]
240
-
241
- BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
242
- current_big_tensor_size = 0
243
- big_tensor_no = 0
244
- big_tensors_sizes = []
245
- tensor_map_indexes = []
246
- total_tensor_bytes = 0
247
-
248
- params_list = []
249
- for k, sub_module in model.named_modules():
250
- include = True
251
- if partialPinning:
252
- include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
253
- if include:
254
- params_list = params_list + list(sub_module.buffers(recurse=False)) + list(sub_module.parameters(recurse=False))
255
-
256
- for p in params_list:
257
- if isinstance(p, QTensor):
258
- length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
259
- else:
260
- length = torch.numel(p.data) * p.data.element_size()
261
-
262
- if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
263
- big_tensors_sizes.append(current_big_tensor_size)
264
- current_big_tensor_size = 0
265
- big_tensor_no += 1
266
- tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
267
- current_big_tensor_size += length
268
-
269
- total_tensor_bytes += length
270
-
271
-
272
- big_tensors_sizes.append(current_big_tensor_size)
273
-
274
- big_tensors = []
275
- last_big_tensor = 0
276
- total = 0
277
-
278
-
279
-
280
- for size in big_tensors_sizes:
281
- try:
282
- current_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True, device="cpu")
283
- big_tensors.append(current_big_tensor)
284
- except:
285
- print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
286
- break
287
-
288
- last_big_tensor += 1
289
- total += size
290
-
291
-
292
- gc.collect()
293
-
294
- tensor_no = 0
295
- for p in params_list:
296
- big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
297
-
298
- if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
299
- current_big_tensor = big_tensors[big_tensor_no]
300
- if isinstance(p, QTensor):
301
- length1 = torch.numel(p._data) * p._data.element_size()
302
- p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
303
- length2 = torch.numel(p._scale) * p._scale.element_size()
304
- p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
305
- else:
306
- length = torch.numel(p.data) * p.data.element_size()
307
- p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
308
-
309
- tensor_no += 1
310
- global total_pinned_bytes
311
- total_pinned_bytes += total
312
- gc.collect()
313
-
314
- if verboseLevel >=1:
315
- if total_tensor_bytes == total:
316
- print(f"The whole model was pinned to RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
317
- else:
318
- print(f"{total/ONE_MB:.2f} MB were pinned to RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
319
-
320
- model._already_pinned = True
321
-
322
-
323
- return
324
- welcome_displayed = False
325
-
326
- def _welcome():
327
- global welcome_displayed
328
- if welcome_displayed:
329
- return
330
- welcome_displayed = True
331
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
332
-
333
-
334
- # def _pin_to_memory_sd(model, sd, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
335
- # if verboseLevel>=1 :
336
- # if partialPinning:
337
- # print(f"Partial pinning to RAM of data of file '{model_id}' while loading it")
338
- # else:
339
- # print(f"Pinning data to RAM of file '{model_id}' while loading it")
340
-
341
- # max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
342
- # if partialPinning:
343
- # towers_names, _ = _detect_main_towers(model)
344
- # towers_names = [n +"." for n in towers_names]
345
-
346
- # BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
347
- # current_big_tensor_size = 0
348
- # big_tensor_no = 0
349
- # big_tensors_sizes = []
350
- # tensor_map_indexes = []
351
- # total_tensor_bytes = 0
352
-
353
- # for k,t in sd.items():
354
- # include = True
355
- # # if isinstance(p, QTensor):
356
- # # length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
357
- # # else:
358
- # # length = torch.numel(p.data) * p.data.element_size()
359
- # length = torch.numel(t) * t.data.element_size()
360
-
361
- # if partialPinning:
362
- # include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
363
-
364
- # if include:
365
- # if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
366
- # big_tensors_sizes.append(current_big_tensor_size)
367
- # current_big_tensor_size = 0
368
- # big_tensor_no += 1
369
- # tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
370
- # current_big_tensor_size += length
371
- # else:
372
- # tensor_map_indexes.append((-1, 0, 0 ))
373
- # total_tensor_bytes += length
374
-
375
- # big_tensors_sizes.append(current_big_tensor_size)
376
-
377
- # big_tensors = []
378
- # last_big_tensor = 0
379
- # total = 0
380
-
381
-
382
- # for size in big_tensors_sizes:
383
- # try:
384
- # currrent_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True)
385
- # big_tensors.append(currrent_big_tensor)
386
- # except:
387
- # print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
388
- # break
389
-
390
- # last_big_tensor += 1
391
- # total += size
392
-
393
-
394
- # tensor_no = 0
395
- # for k,t in sd.items():
396
- # big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
397
- # if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
398
- # current_big_tensor = big_tensors[big_tensor_no]
399
- # # if isinstance(p, QTensor):
400
- # # length1 = torch.numel(p._data) * p._data.element_size()
401
- # # p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
402
- # # length2 = torch.numel(p._scale) * p._scale.element_size()
403
- # # p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
404
- # # else:
405
- # # length = torch.numel(p.data) * p.data.element_size()
406
- # # p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
407
- # length = torch.numel(t) * t.data.element_size()
408
- # t = _move_to_pinned_tensor(t, current_big_tensor, offset, length)
409
- # sd[k] = t
410
- # tensor_no += 1
411
-
412
- # global total_pinned_bytes
413
- # total_pinned_bytes += total
414
-
415
- # if verboseLevel >=1:
416
- # if total_tensor_bytes == total:
417
- # print(f"The whole model was pinned to RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
418
- # else:
419
- # print(f"{total/ONE_MB:.2f} MB were pinned to RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
420
-
421
- # model._already_pinned = True
422
-
423
-
424
- # return
425
-
426
- def _quantize_dirty_hack(model):
427
- # dirty hack: add a hook on state_dict() to return a fake non quantized state_dict if called by Lora Diffusers initialization functions
428
- setattr( model, "_real_state_dict", model.state_dict)
429
- from collections import OrderedDict
430
- import traceback
431
-
432
- def state_dict_for_lora(self):
433
- real_sd = self._real_state_dict()
434
- fakeit = False
435
- stack = traceback.extract_stack(f=None, limit=5)
436
- for frame in stack:
437
- if "_lora_" in frame.name:
438
- fakeit = True
439
- break
440
-
441
- if not fakeit:
442
- return real_sd
443
- sd = OrderedDict()
444
- for k in real_sd:
445
- v = real_sd[k]
446
- if k.endswith("._data"):
447
- k = k[:len(k)-6]
448
- sd[k] = v
449
- return sd
450
-
451
- setattr(model, "state_dict", functools.update_wrapper(functools.partial(state_dict_for_lora, model), model.state_dict) )
452
-
453
- def _quantization_map(model):
454
- from optimum.quanto import quantization_map
455
- return quantization_map(model)
456
-
457
- def _set_module_by_name(parent_module, name, child_module):
458
- module_names = name.split(".")
459
- if len(module_names) == 1:
460
- setattr(parent_module, name, child_module)
461
- else:
462
- parent_module_name = name[: name.rindex(".")]
463
- parent_module = parent_module.get_submodule(parent_module_name)
464
- setattr(parent_module, module_names[-1], child_module)
465
-
466
- def _quantize_submodule(
467
- model: torch.nn.Module,
468
- name: str,
469
- module: torch.nn.Module,
470
- weights = None,
471
- activations = None,
472
- optimizer = None,
473
- ):
474
-
475
- qmodule = quantize_module(module, weights=weights, activations=activations, optimizer=optimizer)
476
- if qmodule is not None:
477
- _set_module_by_name(model, name, qmodule)
478
- qmodule.name = name
479
- for name, param in module.named_parameters():
480
- # Save device memory by clearing parameters
481
- setattr(module, name, None)
482
- del param
483
-
484
- def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict):
485
- # change dtype of current meta model parameters because 'requantize' won't update the dtype on non quantized parameters
486
- for k, p in model.named_parameters():
487
- if not k in quantization_map and k in state_dict:
488
- p_in_file = state_dict[k]
489
- if p.data.dtype != p_in_file.data.dtype:
490
- p.data = p.data.to(p_in_file.data.dtype)
491
-
492
- # rebuild quanto objects
493
- for name, m in model.named_modules():
494
- qconfig = quantization_map.get(name, None)
495
- if qconfig is not None:
496
- weights = qconfig["weights"]
497
- if weights == "none":
498
- weights = None
499
- activations = qconfig["activations"]
500
- if activations == "none":
501
- activations = None
502
- _quantize_submodule(model, name, m, weights=weights, activations=activations)
503
-
504
- model._quanto_map = quantization_map
505
-
506
- _quantize_dirty_hack(model)
507
-
508
-
509
-
510
- def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 1000000000, model_id = 'Unknown'):
511
-
512
- def compute_submodule_size(submodule):
513
- size = 0
514
- for p in submodule.parameters(recurse=False):
515
- size += torch.numel(p.data) * sizeofbfloat16
516
-
517
- for p in submodule.buffers(recurse=False):
518
- size += torch.numel(p.data) * sizeofbfloat16
519
-
520
- return size
521
-
522
- total_size =0
523
- total_excluded = 0
524
- exclude_list = []
525
- submodule_size = 0
526
- submodule_names = []
527
- cur_blocks_prefix = None
528
- prev_blocks_prefix = None
529
-
530
- if hasattr(model_to_quantize, "_quanto_map"):
531
- print(f"Model '{model_id}' is already quantized")
532
- return False
533
-
534
- print(f"Quantization of model '{model_id}' started")
535
-
536
- for submodule_name, submodule in model_to_quantize.named_modules():
537
- if isinstance(submodule, QModuleMixin):
538
- if verboseLevel>=1:
539
- print("No quantization to do as model is already quantized")
540
- return False
541
-
542
-
543
- if submodule_name=='':
544
- continue
545
-
546
-
547
- flush = False
548
- if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
549
- if cur_blocks_prefix == None:
550
- cur_blocks_prefix = submodule_name + "."
551
- flush = True
552
- else:
553
- #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
554
- if not submodule_name.startswith(cur_blocks_prefix):
555
- cur_blocks_prefix = submodule_name + "."
556
- flush = True
557
- else:
558
- if cur_blocks_prefix is not None:
559
- #if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
560
- if not submodule_name.startswith(cur_blocks_prefix):
561
- cur_blocks_prefix = None
562
- flush = True
563
-
564
- if flush:
565
- if submodule_size <= threshold:
566
- exclude_list += submodule_names
567
- if verboseLevel >=2:
568
- print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
569
- total_excluded += submodule_size
570
-
571
- submodule_size = 0
572
- submodule_names = []
573
- prev_blocks_prefix = cur_blocks_prefix
574
- size = compute_submodule_size(submodule)
575
- submodule_size += size
576
- total_size += size
577
- submodule_names.append(submodule_name)
578
-
579
- if submodule_size > 0 and submodule_size <= threshold:
580
- exclude_list += submodule_names
581
- if verboseLevel >=2:
582
- print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
583
- total_excluded += submodule_size
584
-
585
- perc_excluded =total_excluded/ total_size if total_size >0 else 1
586
- if verboseLevel >=2:
587
- print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB oF {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
588
- if perc_excluded >= 0.10:
589
- print(f"Too many many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
590
- exclude_list = None
591
-
592
-
593
- #quantize(model_to_quantize,weights, exclude= exclude_list)
594
- pass
595
- for name, m in model_to_quantize.named_modules():
596
- if exclude_list is None or not any( name == module_name for module_name in exclude_list):
597
- _quantize_submodule(model_to_quantize, name, m, weights=weights, activations=None, optimizer=None)
598
-
599
- # force read non quantized parameters so that their lazy tensors and corresponding mmap are released
600
- # otherwise we may end up keeping in memory both the quantized and the non-quantized model
601
-
602
-
603
- for name, m in model_to_quantize.named_modules():
604
- # do not read quantized weights (whether detected directly or behind an adapter)
605
- if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
606
- pass
607
- else:
608
- if hasattr(m, "weight") and m.weight is not None:
609
- m.weight.data = m.weight.data + 0
610
-
611
- if hasattr(m, "bias") and m.bias is not None:
612
- m.bias.data = m.bias.data + 0
613
-
614
-
615
- freeze(model_to_quantize)
616
- torch.cuda.empty_cache()
617
- gc.collect()
618
- quantization_map = _quantization_map(model_to_quantize)
619
- model_to_quantize._quanto_map = quantization_map
620
-
621
- _quantize_dirty_hack(model_to_quantize)
622
-
623
- print(f"Quantization of model '{model_id}' done")
624
-
625
- return True
626
-
627
- def get_model_name(model):
628
- return model.name
629
-
630
- class HfHook:
631
- def __init__(self):
632
- self.execution_device = "cuda"
633
-
634
- def detach_hook(self, module):
635
- pass
636
-
637
- class offload:
638
- def __init__(self):
639
- self.active_models = []
640
- self.active_models_ids = []
641
- self.active_subcaches = {}
642
- self.models = {}
643
- self.verboseLevel = 0
644
- self.modules_data = {}
645
- self.blocks_of_modules = {}
646
- self.blocks_of_modules_sizes = {}
647
- self.anyCompiledModule = False
648
- self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
649
- self.last_reserved_mem_check =0
650
- self.loaded_blocks = {}
651
- self.prev_blocks_names = {}
652
- self.next_blocks_names = {}
653
- self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
654
- self.transfer_stream = torch.cuda.Stream()
655
- self.async_transfers = False
656
-
657
- def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
658
-
659
- entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
660
- if entry_name in self.blocks_of_modules:
661
- blocks_params = self.blocks_of_modules[entry_name]
662
- blocks_params_size = self.blocks_of_modules_sizes[entry_name]
663
- else:
664
- blocks_params = []
665
- self.blocks_of_modules[entry_name] = blocks_params
666
- blocks_params_size = 0
667
- if blocks_name !=None:
668
-
669
- prev_entry_name = None if prev_block_name == None else model_id + "/" + prev_block_name
670
- self.prev_blocks_names[entry_name] = prev_entry_name
671
- if not prev_block_name == None:
672
- self.next_blocks_names[prev_entry_name] = entry_name
673
-
674
-
675
- for k,p in submodule.named_parameters(recurse=False):
676
- blocks_params.append(p)
677
- if isinstance(p, QTensor):
678
- blocks_params_size += p._data.nbytes
679
- blocks_params_size += p._scale.nbytes
680
- else:
681
- blocks_params_size += p.data.nbytes
682
-
683
- for p in submodule.buffers(recurse=False):
684
- blocks_params.append(p)
685
- blocks_params_size += p.data.nbytes
686
-
687
-
688
- self.blocks_of_modules_sizes[entry_name] = blocks_params_size
689
-
690
- return blocks_params_size
691
-
692
-
693
- def can_model_be_cotenant(self, model_id):
694
- potential_cotenants= cotenants_map.get(model_id, None)
695
- if potential_cotenants is None:
696
- return False
697
- for existing_cotenant in self.active_models_ids:
698
- if existing_cotenant not in potential_cotenants:
699
- return False
700
- return True
701
-
702
- def gpu_load_blocks(self, model_id, blocks_name, async_load = False):
703
- # cl = clock.start()
704
- import weakref
705
-
706
- if blocks_name != None:
707
- self.loaded_blocks[model_id] = blocks_name
708
-
709
- def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream = None):
710
- with torch.cuda.stream(stream_to_use):
711
- for p in blocks_params:
712
- if isinstance(p, QTensor):
713
- # need a formal transfer to cuda otherwise the quantized tensor will still be considered to be on cpu and compilation will fail
714
- q=p.to("cuda",non_blocking=True)
715
- #q = torch.nn.Parameter(q , requires_grad=False)
716
-
717
- ref = weakref.getweakrefs(p)
718
- if ref:
719
- torch._C._swap_tensor_impl(p, q)
720
- else:
721
- torch.utils.swap_tensors(p, q)
722
-
723
- # p._data = p._data.cuda(non_blocking=True)
724
- # p._scale = p._scale.cuda(non_blocking=True)
725
- else:
726
- p.data = p.data.cuda(non_blocking=True)
727
-
728
- if record_for_stream != None:
729
- if isinstance(p, QTensor):
730
- p._data.record_stream(record_for_stream)
731
- p._scale.record_stream(record_for_stream)
732
- else:
733
- p.data.record_stream(record_for_stream)
734
-
735
-
736
- entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
737
- if self.verboseLevel >=2:
738
- model = self.models[model_id]
739
- model_name = model._get_name()
740
- print(f"Loading model {entry_name} ({model_name}) in GPU")
741
-
742
-
743
- if self.async_transfers and blocks_name != None:
744
- first = self.prev_blocks_names[entry_name] == None
745
- next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
746
- if first:
747
- cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
748
- torch.cuda.synchronize()
749
-
750
- if next_blocks_entry != None:
751
- cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry]) #, self.default_stream
752
-
753
- else:
754
- cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
755
- torch.cuda.synchronize()
756
- # cl.stop()
757
- # print(f"load time: {cl.format_time_gap()}")
758
-
759
-
760
- def gpu_unload_blocks(self, model_id, blocks_name):
761
- # cl = clock.start()
762
- import weakref
763
- if blocks_name != None:
764
- self.loaded_blocks[model_id] = None
765
-
766
- blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
767
-
768
- if self.verboseLevel >=2:
769
- model = self.models[model_id]
770
- model_name = model._get_name()
771
- print(f"Unloading model {blocks_name} ({model_name}) from GPU")
772
-
773
- blocks_params = self.blocks_of_modules[blocks_name]
774
- if "transformer/double_blocks.0" == blocks_name:
775
- pass
776
- parameters_data = self.modules_data[model_id]
777
- for p in blocks_params:
778
- if isinstance(p, QTensor):
779
- data = parameters_data[p]
780
-
781
- # need to create a new WeightQBytesTensor from the cached data, which lives on the cpu device (the fake tensor p is still on cuda)
782
- q = WeightQBytesTensor.create(p.qtype, p.axis, p.size(), p.stride(), data[0], data[1], activation_qtype=p.activation_qtype, requires_grad=p.requires_grad )
783
- #q = torch.nn.Parameter(q , requires_grad=False)
784
- ref = weakref.getweakrefs(p)
785
- if ref:
786
- torch._C._swap_tensor_impl(p, q)
787
- else:
788
- torch.utils.swap_tensors(p, q)
789
-
790
- # p._data = data[0]
791
- # p._scale = data[1]
792
- else:
793
- p.data = parameters_data[p]
794
- # cl.stop()
795
- # print(f"unload time: {cl.format_time_gap()}")
796
-
797
-
798
- def gpu_load(self, model_id):
799
- model = self.models[model_id]
800
- self.active_models.append(model)
801
- self.active_models_ids.append(model_id)
802
-
803
- self.gpu_load_blocks(model_id, None)
804
-
805
- # torch.cuda.current_stream().synchronize()
806
-
807
- def unload_all(self):
808
- for model_id in self.active_models_ids:
809
- self.gpu_unload_blocks(model_id, None)
810
- loaded_block = self.loaded_blocks[model_id]
811
- if loaded_block != None:
812
- self.gpu_unload_blocks(model_id, loaded_block)
813
- self.loaded_blocks[model_id] = None
814
-
815
- self.active_models = []
816
- self.active_models_ids = []
817
- self.active_subcaches = []
818
- torch.cuda.empty_cache()
819
- gc.collect()
820
- self.last_reserved_mem_check = time.time()
821
-
822
- def move_args_to_gpu(self, *args, **kwargs):
823
- new_args= []
824
- new_kwargs={}
825
- for arg in args:
826
- if torch.is_tensor(arg):
827
- if arg.dtype == torch.float32:
828
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
829
- else:
830
- arg = arg.cuda(non_blocking=True)
831
- new_args.append(arg)
832
-
833
- for k in kwargs:
834
- arg = kwargs[k]
835
- if torch.is_tensor(arg):
836
- if arg.dtype == torch.float32:
837
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
838
- else:
839
- arg = arg.cuda(non_blocking=True)
840
- new_kwargs[k]= arg
841
-
842
- return new_args, new_kwargs
843
-
844
- def ready_to_check_mem(self):
845
- if self.anyCompiledModule:
846
- return
847
- cur_clock = time.time()
848
- # can't check at each call whether we can empty the cuda cache as querying the reserved memory value is a time-consuming operation
849
- if (cur_clock - self.last_reserved_mem_check)<0.200:
850
- return False
851
- self.last_reserved_mem_check = cur_clock
852
- return True
853
-
854
-
855
- def empty_cache_if_needed(self):
856
- mem_reserved = torch.cuda.memory_reserved()
857
- mem_threshold = 0.9*self.device_mem_capacity
858
- if mem_reserved >= mem_threshold:
859
- mem_allocated = torch.cuda.memory_allocated()
860
- if mem_allocated <= 0.70 * mem_reserved:
861
- # print(f"Cuda empty cache triggered as Allocated Memory ({mem_allocated/1024000:0f} MB) is lot less than Cached Memory ({mem_reserved/1024000:0f} MB) ")
862
- torch.cuda.empty_cache()
863
- tm= time.time()
864
- if self.verboseLevel >=2:
865
- print(f"Empty Cuda cache at {tm}")
866
- # print(f"New cached memory after purge is {torch.cuda.memory_reserved()/1024000:0f} MB) ")
867
-
868
-
869
- def any_param_or_buffer(self, target_module: torch.nn.Module):
870
-
871
- for _ in target_module.parameters(recurse= False):
872
- return True
873
-
874
- for _ in target_module.buffers(recurse= False):
875
- return True
876
-
877
- return False
878
-
879
- def hook_load_data_if_needed(self, target_module, model_id,blocks_name, context):
880
-
881
- @torch.compiler.disable()
882
- def load_data_if_needed(module, *args, **kwargs):
883
- some_context = context #for debugging
884
- if blocks_name == None:
885
- if self.ready_to_check_mem():
886
- self.empty_cache_if_needed()
887
- else:
888
- loaded_block = self.loaded_blocks[model_id]
889
- if (loaded_block == None or loaded_block != blocks_name) :
890
- if loaded_block != None:
891
- self.gpu_unload_blocks(model_id, loaded_block)
892
- if self.ready_to_check_mem():
893
- self.empty_cache_if_needed()
894
- self.loaded_blocks[model_id] = blocks_name
895
- self.gpu_load_blocks(model_id, blocks_name)
896
-
897
- target_module.register_forward_pre_hook(load_data_if_needed)
898
-
899
-
900
- def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
901
-
902
- def check_empty_cuda_cache(module, *args, **kwargs):
903
- # if self.ready_to_check_mem():
904
- # self.empty_cache_if_needed()
905
- if blocks_name == None:
906
- if self.ready_to_check_mem():
907
- self.empty_cache_if_needed()
908
- else:
909
- loaded_block = self.loaded_blocks[model_id]
910
- if (loaded_block == None or loaded_block != blocks_name) :
911
- if loaded_block != None:
912
- self.gpu_unload_blocks(model_id, loaded_block)
913
- if self.ready_to_check_mem():
914
- self.empty_cache_if_needed()
915
- self.loaded_blocks[model_id] = blocks_name
916
- self.gpu_load_blocks(model_id, blocks_name)
917
-
918
- return previous_method(*args, **kwargs)
919
-
920
-
921
- if hasattr(target_module, "_mm_id"):
922
- orig_model_id = getattr(target_module, "_mm_id")
923
- if self.verboseLevel >=2:
924
- print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module '{orig_model_id}' ")
925
- assert not self.any_param_or_buffer(target_module)
926
-
927
- return
928
- setattr(target_module, "_mm_id", model_id)
929
- setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
930
-
931
-
932
- def hook_change_module(self, target_module, model, model_id, module_id, previous_method):
933
- def check_change_module(module, *args, **kwargs):
934
- performEmptyCacheTest = False
935
- if not model_id in self.active_models_ids:
936
- new_model_id = getattr(module, "_mm_id")
937
- # do not always unload existing models if it is more efficient to keep them in the GPU
938
- # (e.g. small models such as text encoders)
939
- if not self.can_model_be_cotenant(new_model_id) :
940
- self.unload_all()
941
- performEmptyCacheTest = False
942
- self.gpu_load(new_model_id)
943
- # transfer leftover inputs that were incorrectly created in RAM (mostly due to some .device tests that incorrectly returned "cpu")
944
- args, kwargs = self.move_args_to_gpu(*args, **kwargs)
945
- if performEmptyCacheTest:
946
- self.empty_cache_if_needed()
947
-
948
- return previous_method(*args, **kwargs)
949
-
950
- if hasattr(target_module, "_mm_id"):
951
- return
952
- setattr(target_module, "_mm_id", model_id)
953
-
954
- setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method) )
955
-
956
- if not self.verboseLevel >=1:
957
- return
958
-
959
- if module_id == None or module_id =='':
960
- model_name = model._get_name()
961
- print(f"Hooked in model '{model_id}' ({model_name})")
962
-
963
-
964
- # Not implemented yet, but why would one want to get rid of these features ?
965
- # def unhook_module(module: torch.nn.Module):
966
- # if not hasattr(module,"_mm_id"):
967
- # return
968
-
969
- # delattr(module, "_mm_id")
970
-
971
- # def unhook_all(parent_module: torch.nn.Module):
972
- # for module in parent_module.components.items():
973
- # self.unhook_module(module)
974
-
975
- def fast_load_transformers_model(model_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verbose_level = 1):
976
- """
977
- quick version of the from_pretrained method of the transformers library
978
- used to build a model and load the corresponding weights (quantized or not)
979
- """
980
- import os.path
981
- from accelerate import init_empty_weights
982
-
983
- if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
984
- raise Exception("full model path to file expected")
985
-
986
- model_path = _get_model(model_path)
987
-
988
- with safetensors2.safe_open(model_path) as f:
989
- metadata = f.metadata()
990
-
991
- if metadata is None:
992
- transformer_config = None
993
- else:
994
- transformer_config = metadata.get("config", None)
995
-
996
- if transformer_config == None:
997
- config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
998
-
999
- if not os.path.isfile(config_fullpath):
1000
- raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
1001
-
1002
- with open(config_fullpath, "r", encoding="utf-8") as reader:
1003
- text = reader.read()
1004
- transformer_config= json.loads(text)
1005
-
1006
-
1007
- if "architectures" in transformer_config:
1008
- architectures = transformer_config["architectures"]
1009
- class_name = architectures[0]
1010
-
1011
- module = __import__("transformers")
1012
- transfomer_class = getattr(module, class_name)
1013
- from transformers import AutoConfig
1014
-
1015
- import tempfile
1016
- with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
1017
- fp.write(json.dumps(transformer_config))
1018
- fp.close()
1019
- config_obj = AutoConfig.from_pretrained(fp.name)
1020
- os.remove(fp.name)
1021
-
1022
- #needed to keep inits of non persistent buffers
1023
- with init_empty_weights():
1024
- model = transfomer_class(config_obj)
1025
-
1026
- model = model.base_model
1027
-
1028
- elif "_class_name" in transformer_config:
1029
- class_name = transformer_config["_class_name"]
1030
-
1031
- module = __import__("diffusers")
1032
- transfomer_class = getattr(module, class_name)
1033
-
1034
- with init_empty_weights():
1035
- model = transfomer_class.from_config(transformer_config)
1036
-
1037
-
1038
- torch.set_default_device('cpu')
1039
-
1040
- model._config = transformer_config
1041
-
1042
- load_model_data(model,model_path, do_quantize = do_quantize, quantization_type = quantization_type, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verbose_level )
1043
-
1044
- return model
1045
-
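- # Hedged usage sketch for fast_load_transformers_model (the checkpoint path is an illustrative
- # placeholder, not a file shipped with mmgp); it expects a full .safetensors / .sft file path:
- #   text_encoder = offload.fast_load_transformers_model("ckpts/llm_quanto_int8.safetensors",
- #                                                        pinToMemory=True, partialPinning=True)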
1046
-
1047
-
1048
- def load_model_data(model, file_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = 1):
1049
- """
1050
- Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
1051
- """
1052
-
1053
- file_path = _get_model(file_path)
1054
- safetensors2.verboseLevel = verboseLevel
1055
- model = _remove_model_wrapper(model)
1056
-
1057
- # if pinToMemory and do_quantize:
1058
- # raise Exception("Pinning and Quantization can not be used at the same time")
1059
-
1060
- if not (".safetensors" in file_path or ".sft" in file_path):
1061
- if pinToMemory:
1062
- raise Exception("Pinning to memory while loading only supported for safe tensors files")
1063
- state_dict = torch.load(file_path, weights_only=True)
1064
- if "module" in state_dict:
1065
- state_dict = state_dict["module"]
1066
- else:
1067
- state_dict, metadata = _safetensors_load_file(file_path)
1068
-
1069
-
1070
- # if pinToMemory:
1071
- # _pin_to_memory_sd(model,state_dict, file_path, partialPinning = partialPinning, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel = verboseLevel)
1072
-
1073
- # with safetensors2.safe_open(file_path) as f:
1074
- # metadata = f.metadata()
1075
-
1076
-
1077
- if metadata is None:
1078
- quantization_map = None
1079
- else:
1080
- quantization_map = metadata.get("quantization_map", None)
1081
- config = metadata.get("config", None)
1082
- if config is not None:
1083
- model._config = config
1084
-
1085
-
1086
-
1087
- if quantization_map is None:
1088
- pos = str.rfind(file_path, ".")
1089
- if pos > 0:
1090
- quantization_map_path = file_path[:pos]
1091
- quantization_map_path += "_map.json"
1092
-
1093
- if os.path.isfile(quantization_map_path):
1094
- with open(quantization_map_path, 'r') as f:
1095
- quantization_map = json.load(f)
1096
-
1097
-
1098
-
1099
- if quantization_map is None :
1100
- if "quanto" in file_path and not do_quantize:
1101
- print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
1102
- else:
1103
- _requantize(model, state_dict, quantization_map)
1104
-
1105
- missing_keys , unexpected_keys = model.load_state_dict(state_dict, strict = quantization_map is None, assign = True )
1106
- del state_dict
1107
-
1108
- if do_quantize:
1109
- if quantization_map is None:
1110
- if _quantize(model, quantization_type, verboseLevel=verboseLevel, model_id=file_path):
1111
- quantization_map = model._quanto_map
1112
- else:
1113
- if verboseLevel >=1:
1114
- print("Model already quantized")
1115
-
1116
- if pinToMemory:
1117
- _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
1118
-
1119
- return
1120
-
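- # Hedged usage sketch for load_model_data (the model class and file name below are illustrative
- # assumptions): instantiate the architecture first, then load the weights into it.
- #   model = MyTransformerClass(config)                      # hypothetical empty model skeleton
- #   offload.load_model_data(model, "model_quanto_int8.safetensors", pinToMemory=True)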
1121
- def save_model(model, file_path, do_quantize = False, quantization_type = qint8, verboseLevel = 1 ):
1122
- """save the weights of a model and quantize them if requested
1123
- These weights can be loaded again using 'load_model_data'
1124
- """
1125
-
1126
- config = None
1127
-
1128
- if hasattr(model, "_config"):
1129
- config = model._config
1130
- elif hasattr(model, "config"):
1131
- config_fullpath = None
1132
- config_obj = getattr(model,"config")
1133
- config_path = getattr(config_obj,"_name_or_path", None)
1134
- if config_path != None:
1135
- config_fullpath = os.path.join(config_path, "config.json")
1136
- if not os.path.isfile(config_fullpath):
1137
- config_fullpath = None
1138
- if config_fullpath is None:
1139
- config_fullpath = os.path.join(os.path.dirname(file_path), "config.json")
1140
- if os.path.isfile(config_fullpath):
1141
- with open(config_fullpath, "r", encoding="utf-8") as reader:
1142
- text = reader.read()
1143
- config= json.loads(text)
1144
-
1145
- if do_quantize:
1146
- _quantize(model, weights=quantization_type, model_id=file_path)
1147
-
1148
- quantization_map = getattr(model, "_quanto_map", None)
1149
-
1150
- if verboseLevel >=1:
1151
- print(f"Saving file '{file_path}")
1152
- safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
1153
- if verboseLevel >=1:
1154
- print(f"File '{file_path} saved")
1155
-
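- # Hedged usage sketch for save_model (the output file name is illustrative): quantize while saving
- # so the file can later be reloaded with load_model_data without on-the-fly quantization.
- #   offload.save_model(pipe.transformer, "transformer_quanto_int8.safetensors", do_quantize=True)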
1156
-
1157
-
1158
-
1159
- def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = 1):
1160
- """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
1161
- pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
1162
- quantizeTransformer: True by default, quantizes the video / image model on the fly
1163
- pinnedMemory: move models into reserved (pinned) memory. This allows very fast performance but requires 50% extra RAM (usually >=64 GB)
1164
- extraModelsToQuantize: a list of models to also be quantized on the fly (e.g. the text_encoder), useful to reduce both RAM and VRAM consumption
1165
- budgets: 0 by default (unlimited). If non zero, it corresponds to the maximum size in MB that each model will occupy at any moment
1166
- (in fact the real usage is twice this number). It is very efficient at reducing VRAM consumption but this feature may be very slow
1167
- if pinnedMemory is not enabled
1168
- """
1169
- self = offload()
1170
- self.verboseLevel = verboseLevel
1171
- safetensors2.verboseLevel = verboseLevel
1172
- self.modules_data = {}
1173
- model_budgets = {}
1174
-
1175
- windows_os = os.name == 'nt'
1176
- global total_pinned_bytes
1177
-
1178
-
1179
- budget = 0
1180
- if not budgets is None:
1181
- if isinstance(budgets , dict):
1182
- model_budgets = budgets
1183
- else:
1184
- budget = int(budgets) * ONE_MB
1185
-
1186
- # if (budgets!= None or budget >0) :
1187
- # self.async_transfers = True
1188
- self.async_transfers = asyncTransfers
1189
-
1190
-
1191
-
1192
- torch.set_default_device('cpu')
1193
-
1194
- if hasattr(pipe_or_dict_of_modules, "components"):
1195
- # create a fake Accelerate parameter so that lora loading doesn't change the device
1196
- pipe_or_dict_of_modules.hf_device_map = torch.device("cuda")
1197
- pipe_or_dict_of_modules= pipe_or_dict_of_modules.components
1198
-
1199
-
1200
- models = {k: _remove_model_wrapper(v) for k, v in pipe_or_dict_of_modules.items() if isinstance(v, torch.nn.Module)}
1201
-
1202
-
1203
-
1204
- _welcome()
1205
-
1206
- self.models = models
1207
-
1208
- extraModelsToQuantize = extraModelsToQuantize if extraModelsToQuantize is not None else []
1209
- if not isinstance(extraModelsToQuantize, list):
1210
- extraModelsToQuantize= [extraModelsToQuantize]
1211
- if quantizeTransformer:
1212
- extraModelsToQuantize.append("transformer")
1213
- models_to_quantize = extraModelsToQuantize
1214
-
1215
- modelsToPin = []
1216
- pinAllModels = False
1217
- if isinstance(pinnedMemory, bool):
1218
- pinAllModels = pinnedMemory
1219
- elif isinstance(pinnedMemory, list):
1220
- modelsToPin = pinnedMemory
1221
- else:
1222
- modelsToPin = [pinnedMemory]
1223
-
1224
- modelsToCompile = []
1225
- compileAllModels = False
1226
- if isinstance(compile, bool):
1227
- compileAllModels = compile
1228
- elif isinstance(compile, list):
1229
- modelsToCompile = compile
1230
- else:
1231
- modelsToCompile = [compile]
1232
-
1233
- self.anyCompiledModule = compileAllModels or len(modelsToCompile)>0
1234
- if self.anyCompiledModule:
1235
- torch._dynamo.config.cache_size_limit = 10000
1236
-
1237
- max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
1238
-
1239
- estimatesBytesToPin = 0
1240
-
1241
- for model_id in models:
1242
- current_model: torch.nn.Module = models[model_id]
1243
- # make sure that no RAM or GPU memory is allocated for gradients / training
1244
- current_model.to("cpu").eval()
1245
-
1246
- # quantize the model on the fly if requested (_quantize is a no-op if the model has already been quantized)
1247
- if model_id in models_to_quantize:
1248
- _quantize(current_model, weights=qint8, verboseLevel = self.verboseLevel, model_id=model_id)
1249
-
1250
- modelPinned = (pinAllModels or model_id in modelsToPin) and not hasattr(current_model,"_already_pinned")
1251
-
1252
- current_model_size = 0
1253
- # load all the remaining unread lazy safetensors in RAM to free open cache files
1254
- for p in current_model.parameters():
1255
- if isinstance(p, QTensor):
1256
- # # fix quanto bug (seems to have been fixed)
1257
- # if not modelPinned and p._scale.dtype == torch.float32:
1258
- # p._scale = p._scale.to(torch.bfloat16)
1259
- current_model_size += torch.numel(p._scale) * p._scale.element_size()
1260
- current_model_size += torch.numel(p._data) * p._data.element_size()
1261
- else:
1262
- if p.data.dtype == torch.float32:
1263
- # convert any leftover float32 weights to bfloat16 to halve the model memory footprint
1264
- p.data = p.data.to(torch.bfloat16)
1265
- current_model_size += torch.numel(p.data) * p.data.element_size()
1266
-
1267
- for b in current_model.buffers():
1268
- if b.data.dtype == torch.float32:
1269
- # convert any leftover float32 buffers to bfloat16 to halve the model memory footprint
1270
- b.data = b.data.to(torch.bfloat16)
1271
- current_model_size += torch.numel(b.data) * b.data.element_size()
1272
-
1273
- if modelPinned:
1274
- estimatesBytesToPin += current_model_size
1275
-
1276
-
1277
- model_budget = model_budgets[model_id] * ONE_MB if model_id in model_budgets else budget
1278
-
1279
- if model_budget > 0 and model_budget > current_model_size:
1280
- model_budget = 0
1281
-
1282
- model_budgets[model_id] = model_budget
1283
-
1284
- partialPinning = False
1285
-
1286
- if estimatesBytesToPin > 0 and estimatesBytesToPin >= (max_reservable_memory - total_pinned_bytes):
1287
- if self.verboseLevel >=1:
1288
- print(f"Switching to partial pinning since full requirements for pinned models is {estimatesBytesToPin/ONE_MB:0.1f} MB while estimated reservable RAM is {max_reservable_memory/ONE_MB:0.1f} MB" )
1289
- partialPinning = True
1290
-
1291
- # Hook forward methods of modules
1292
- for model_id in models:
1293
- current_model: torch.nn.Module = models[model_id]
1294
- current_budget = model_budgets[model_id]
1295
- current_size = 0
1296
- cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1297
- self.loaded_blocks[model_id] = None
1298
- towers_names, towers_modules = _detect_main_towers(current_model)
1299
- towers_names = [n +"." for n in towers_names]
1300
- if self.verboseLevel>=2 and len(towers_names)>0:
1301
- print(f"Potential iterative blocks found in model '{model_id}':{towers_names}")
1302
- # compile the main iterative module stacks ("towers")
1303
- if compileAllModels or model_id in modelsToCompile :
1304
- #torch.compiler.reset()
1305
- if self.verboseLevel>=1:
1306
- print(f"Pytorch compilation of model '{model_id}' is scheduled.")
1307
- for tower in towers_modules:
1308
- for submodel in tower:
1309
- submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
1310
-
1311
-
1312
- for submodule_name, submodule in current_model.named_modules():
1313
- # create a fake 'accelerate' parameter so that the _execution_device property always returns "cuda"
1314
- # (it is queried in many pipelines even if offloading is not properly implemented)
1315
- if not hasattr(submodule, "_hf_hook"):
1316
- setattr(submodule, "_hf_hook", HfHook())
1317
-
1318
- if submodule_name=='':
1319
- continue
1320
- newListItem = False
1321
- if current_budget > 0:
1322
- if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): #
1323
- if cur_blocks_prefix == None:
1324
- cur_blocks_prefix = submodule_name + "."
1325
- else:
1326
- #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
1327
- if not submodule_name.startswith(cur_blocks_prefix):
1328
- cur_blocks_prefix = submodule_name + "."
1329
- cur_blocks_name,cur_blocks_seq = None, -1
1330
- else:
1331
-
1332
- if cur_blocks_prefix is not None:
1333
- if submodule_name.startswith(cur_blocks_prefix):
1334
- num = int(submodule_name[len(cur_blocks_prefix):].split(".")[0])
1335
- newListItem= num != cur_blocks_seq
1336
- if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
1337
- prev_blocks_name = cur_blocks_name
1338
- cur_blocks_name = cur_blocks_prefix + str(num)
1339
- # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
1340
- cur_blocks_seq = num
1341
- else:
1342
- cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1343
-
1344
- if hasattr(submodule, "forward"):
1345
- submodule_method = getattr(submodule, "forward")
1346
- if callable(submodule_method):
1347
- if len(submodule_name.split("."))==1:
1348
- self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
1349
- elif newListItem:
1350
- self.hook_load_data_if_needed(submodule, model_id, cur_blocks_name, context = submodule_name )
1351
- else:
1352
- self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name )
1353
-
1354
-
1355
- current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
1356
-
1357
-
1358
- parameters_data = {}
1359
- if pinAllModels or model_id in modelsToPin:
1360
- if hasattr(current_model,"_already_pinned"):
1361
- if self.verboseLevel >=1:
1362
- print(f"Model '{model_id}' already pinned to reserved memory")
1363
- else:
1364
- _pin_to_memory(current_model, model_id, partialPinning= partialPinning, perc_reserved_mem_max=perc_reserved_mem_max, verboseLevel=verboseLevel)
1365
-
1366
-
1367
- for p in current_model.parameters():
1368
- parameters_data[p] = [p._data, p._scale] if isinstance(p, QTensor) else p.data
1369
-
1370
- buffers_data = {b: b.data for b in current_model.buffers()}
1371
- parameters_data.update(buffers_data)
1372
- self.modules_data[model_id]=parameters_data
1373
-
1374
- if self.verboseLevel >=2:
1375
- for n,b in self.blocks_of_modules_sizes.items():
1376
- print(f"Size of submodel '{n}': {b/ONE_MB:.1f} MB")
1377
-
1378
- torch.set_default_device('cuda')
1379
- torch.cuda.empty_cache()
1380
- gc.collect()
1381
-
1382
- return self
1383
-
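- # Hedged usage sketch for offload.all with per-model budgets (the figures, in MB as documented
- # above, are illustrative):
- #   offload.all(pipe,
- #               pinnedMemory=["transformer"],             # pin only the main model to RAM
- #               extraModelsToQuantize=["text_encoder_2"],
- #               budgets={"transformer": 600, "text_encoder_2": 3000},
- #               asyncTransfers=True)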
1384
-
1385
- def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.VerylowRAM_LowVRAM_Slowest , verboseLevel = 1, **overrideKwargs):
1386
- """Apply a configuration profile that depends on your hardware:
1387
- pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
1388
- profile_no : number of the profile:
1389
- HighRAM_HighVRAM_Fastest (=1): at least 48 GB of RAM and 24 GB of VRAM : the fastest, well suited for an RTX 3090 / RTX 4090
1390
- HighRAM_LowVRAM_Fast (=2): at least 48 GB of RAM and 12 GB of VRAM : a bit slower, better suited for RTX 3070/3080/4070/4080
1391
- or for an RTX 3090 / RTX 4090 with large picture batches or long videos
1392
- LowRAM_HighVRAM_Medium (=3): at least 32 GB of RAM and 24 GB of VRAM : so-so speed but suited to an RTX 3090 / RTX 4090 with limited RAM
1393
- LowRAM_LowVRAM_Slow (=4): at least 32 GB of RAM and 12 GB of VRAM : if you have little VRAM or generate longer videos
1394
- VerylowRAM_LowVRAM_Slowest (=5): at least 24 GB of RAM and 10 GB of VRAM : if you don't have much it won't be fast but maybe it will work
1395
- overrideKwargs: every parameter accepted by offload.all can be added here to override the profile choice
1396
- For instance set quantizeTransformer = False to disable transformer quantization, which is enabled by default in every profile
1397
- """
1398
-
1399
- _welcome()
1400
-
1401
- modules = pipe_or_dict_of_modules
1402
-
1403
- if hasattr(modules, "components"):
1404
- modules= modules.components
1405
-
1406
- modules = {k: _remove_model_wrapper(v) for k, v in modules.items() if isinstance(v, torch.nn.Module)}
1407
- module_names = {k: v.__module__.lower() for k, v in modules.items() }
1408
-
1409
- default_extraModelsToQuantize = []
1410
- quantizeTransformer = True
1411
-
1412
- models_to_scan = ("text_encoder", "text_encoder_2")
1413
- candidates_to_quantize = ("t5", "llama", "llm")
1414
- for model_id in models_to_scan:
1415
- name = module_names[model_id]
1416
- for candidate in candidates_to_quantize:
1417
- if candidate in name:
1418
- default_extraModelsToQuantize.append(model_id)
1419
- break
1420
-
1421
-
1422
- # the budget for the transformer (video or image generator) should be as small as possible so that it doesn't occupy space that could be used by actual image data
1423
- # on the other hand the text encoder should be quite large (as long as it fits in 10 GB of VRAM) to reduce sequence offloading
1424
-
1425
- default_budgets = { "transformer" : 600 , "text_encoder": 3000, "text_encoder_2": 3000 }
1426
- extraModelsToQuantize = None
1427
-
1428
- if profile_no == profile_type.HighRAM_HighVRAM_Fastest:
1429
- pinnedMemory= True
1430
- budgets = None
1431
- info = "You have chosen a Very Fast profile that requires at least 48 GB of RAM and 24 GB of VRAM."
1432
- elif profile_no == profile_type.HighRAM_LowVRAM_Fast:
1433
- pinnedMemory= True
1434
- budgets = default_budgets
1435
- info = "You have chosen a Fast profile that requires at least 48 GB of RAM and 12 GB of VRAM."
1436
- elif profile_no == profile_type.LowRAM_HighVRAM_Medium:
1437
- pinnedMemory= "transformer"
1438
- extraModelsToQuantize = default_extraModelsToQuantize
1439
- info = "You have chosen a Medium speed profile that requires at least 32 GB of RAM and 24 GB of VRAM."
1440
- elif profile_no == profile_type.LowRAM_LowVRAM_Slow:
1441
- pinnedMemory= "transformer"
1442
- extraModelsToQuantize = default_extraModelsToQuantize
1443
- budgets=default_budgets
1444
- asyncTransfers = True
1445
- info = "You have chosen the Slow profile that requires at least 32 GB of RAM and 12 GB of VRAM."
1446
- elif profile_no == profile_type.VerylowRAM_LowVRAM_Slowest:
1447
- pinnedMemory= False
1448
- extraModelsToQuantize = default_extraModelsToQuantize
1449
- budgets=default_budgets
1450
- budgets["transformer"] = 400
1451
- asyncTransfers = False
1452
- info = "You have chosen the Slowest profile that requires at least 24 GB of RAM and 10 GB of VRAM."
1453
- else:
1454
- raise Exception("Unknown profile")
1455
- CrLf = '\r\n'
1456
- kwargs = { "pinnedMemory": pinnedMemory, "extraModelsToQuantize" : extraModelsToQuantize, "budgets": budgets, "asyncTransfers" : asyncTransfers, "quantizeTransformer": quantizeTransformer }
1457
-
1458
- if verboseLevel>=2:
1459
- info = info + CrLf + f"Profile '{profile_type.tostr(profile_no)}' sets the following options:"
1460
- for k,v in kwargs.items():
1461
- if k in overrideKwargs:
1462
- info = info + CrLf + f"- '{k}': '{kwargs[k]}' overriden with value '{overrideKwargs[k]}'"
1463
- else:
1464
- info = info + CrLf + f"- '{k}': '{kwargs[k]}'"
1465
-
1466
- for k,v in overrideKwargs.items():
1467
- kwargs[k] = overrideKwargs[k]
1468
-
1469
- if info:
1470
- print(info)
1471
-
1472
- return all(pipe_or_dict_of_modules, verboseLevel = verboseLevel, **kwargs)
1
+ # ------------------ Memory Management 3.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
+ #
3
+ # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
4
+ # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
5
+ # times in a pipe (eg VAE).
6
+ #
7
+ # Requirements:
8
+ # - VRAM: minimum 12 GB, recommended 24 GB (RTX 3090/ RTX 4090)
9
+ # - RAM: minimum 24 GB, recommended 48 - 64 GB
10
+ #
11
+ # It is almost plug and play and just needs to be invoked from the main app just after the model pipeline has been created.
12
+ # Make sure that the pipeline explictly loads the models in the CPU device
13
+ # for instance: pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cpu")
14
+ # For a quick setup, you may want to choose between 5 profiles depending on your hardware, for instance:
15
+ # from mmgp import offload, profile_type
16
+ # offload.profile(pipe, profile_type.HighRAM_LowVRAM_Fast)
17
+ # Alternatively you may want to your own parameters, for instance:
18
+ # from mmgp import offload
19
+ # offload.all(pipe, pinToMemory=true, extraModelsToQuantize = ["text_encoder_2"] )
20
+ # The 'transformer' model that contains usually the video or image generator is quantized on the fly by default to 8 bits so that it can fit into 24 GB of VRAM.
21
+ # You can prevent the transformer quantization by adding the parameter quantizeTransformer = False
22
+ # If you want to save time on disk and reduce the loading time, you may want to load directly a prequantized model. In that case you need to set the option quantizeTransformer to False to turn off on the fly quantization.
23
+ # You can specify a list of additional models string ids to quantize (for instance the text_encoder) using the optional argument extraModelsToQuantize. This may be useful if you have less than 48 GB of RAM.
24
+ # Note that there is little advantage on the GPU / VRAM side to quantize text encoders as their inputs are usually quite light.
25
+ # Conversely if you have more than 48GB RAM you may want to enable RAM pinning with the option pinnedMemory = True. You will get in return super fast loading / unloading of models
26
+ # (this can save significant time if the same pipeline is run multiple times in a row)
27
+ #
28
+ # Sometime there isn't an explicit pipe object as each submodel is loaded separately in the main app. If this is the case, you need to create a dictionary that manually maps all the models.
29
+ #
30
+ # For instance :
31
+ # for flux derived models: pipe = { "text_encoder": clip, "text_encoder_2": t5, "transformer": model, "vae":ae }
32
+ # for mochi: pipe = { "text_encoder": self.text_encoder, "transformer": self.dit, "vae":self.decoder }
33
+ #
34
+ # Please note that there should be always one model whose Id is 'transformer'. It corresponds to the main image / video model which usually needs to be quantized (this is done on the fly by default when loading the model)
35
+ #
36
+ # Be careful: lots of models use the T5 XXL as a text encoder. However, quite often their corresponding pipeline configurations point at the official Google T5 XXL repository
37
+ # where there is a huge 40 GB model to download and load. It is cumbersome as it is a 32-bit model and contains the decoder part of T5, which is not used.
38
+ # I suggest you use instead one of the 16-bit encoder-only versions available around, for instance:
39
+ # text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=torch.float16)
40
+ #
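For illustration, a minimal sketch of swapping in the lighter encoder mentioned just above before the offload hooks are installed; 'pipe' is assumed to be the pipeline that was created earlier and loaded on the CPU:

    import torch
    from transformers import T5EncoderModel

    # replace the heavy 32-bit encoder/decoder checkpoint with the fp16 encoder-only weights
    pipe.text_encoder_2 = T5EncoderModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=torch.float16
    )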
41
+ # Sometimes just providing the pipe won't be sufficient and you will need to change the content of the core model:
42
+ # - For instance you may need to disable CPU offload logic that already exists (such as manual calls to move tensors between cuda and the cpu)
43
+ # - mmgp tries to fake the device as being "cuda", but sometimes some code won't be fooled and will create tensors on the cpu device, which may cause some issues.
44
+ #
45
+ # You are free to use my module for non commercial use as long you give me proper credits. You may contact me on twitter @deepbeepmeep
46
+ #
47
+ # Thanks to
48
+ # ---------
49
+ # Huggingface / accelerate for the hooking examples
50
+ # Huggingface / quanto for their very useful quantizer
51
+ # gau-nernst for his RAM pinning samples
52
+
53
+
54
+ #
55
+
56
+ import torch
57
+ import gc
58
+ import time
59
+ import functools
60
+ import sys
61
+ import os
62
+ import json
63
+ import psutil
64
+ from mmgp import safetensors2
65
+ from mmgp import profile_type
66
+
67
+ from optimum.quanto import freeze, qfloat8, qint8, quantize, QModuleMixin, QTensor, WeightQBytesTensor, quantize_module
68
+
69
+
70
+
71
+
72
+ mmm = safetensors2.mmm
73
+
74
+ default_verboseLevel = 1
75
+
76
+ ONE_MB = 1048576
77
+ sizeofbfloat16 = torch.bfloat16.itemsize
78
+ sizeofint8 = torch.int8.itemsize
79
+ total_pinned_bytes = 0
80
+ physical_memory= psutil.virtual_memory().total
81
+
82
+ HEADER = '\033[95m'
83
+ ENDC = '\033[0m'
84
+ BOLD ='\033[1m'
85
+ UNBOLD ='\033[0m'
86
+
87
+ cotenants_map = {
88
+ "text_encoder": ["vae", "text_encoder_2"],
89
+ "text_encoder_2": ["vae", "text_encoder"],
90
+ }
91
+
92
+ class clock:
93
+ def __init__(self):
94
+ self.start_time = 0
95
+ self.stop_time = 0
96
+
97
+ @classmethod
98
+ def start(cls):
99
+ self = cls()
100
+ self.start_time =time.time()
101
+ return self
102
+
103
+ def stop(self):
104
+ self.stop_time =time.time()
105
+
106
+ def time_gap(self):
107
+ return self.stop_time - self.start_time
108
+
109
+ def format_time_gap(self):
110
+ return f"{self.stop_time - self.start_time:.2f}s"
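A minimal usage sketch of this timing helper (the timed work below is just a placeholder):

    from mmgp.offload import clock  # only needed when used outside this module

    cl = clock.start()
    # ... run the work to be timed, e.g. a text encoder forward pass ...
    cl.stop()
    print(f"elapsed: {cl.format_time_gap()}")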
111
+
112
+
113
+
114
+ # useful functions to move a group of tensors (to design custom offload patches)
115
+ def move_tensors(obj, device):
116
+ if torch.is_tensor(obj):
117
+ return obj.to(device)
118
+ elif isinstance(obj, dict):
119
+ _dict = {}
120
+ for k, v in obj.items():
121
+ _dict[k] = move_tensors(v, device)
122
+ return _dict
123
+ elif isinstance(obj, list):
124
+ _list = []
125
+ for v in obj:
126
+ _list.append(move_tensors(v, device))
127
+ return _list
128
+ else:
129
+ raise TypeError("Tensor or list / dict of tensors expected")
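For example, a custom offload patch might push a nested batch of tensors to the GPU and pull it back with this helper. A sketch; the tensor names and shapes are made up and a CUDA device is assumed:

    import torch
    from mmgp.offload import move_tensors

    batch = {"latents": torch.randn(1, 4, 32, 32), "masks": [torch.ones(1, 32, 32)]}
    batch_gpu = move_tensors(batch, "cuda")   # every tensor in the dict / list is moved to the GPU
    batch_cpu = move_tensors(batch_gpu, "cpu")  # and moved back to RAM once it is no longer needed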
130
+
131
+ def _compute_verbose_level(level):
132
+ if level <0:
133
+ level = safetensors2.verboseLevel = default_verboseLevel
134
+ safetensors2.verboseLevel = level
135
+ return level
136
+
137
+ def _get_max_reservable_memory(perc_reserved_mem_max):
138
+ if perc_reserved_mem_max<=0:
139
+ perc_reserved_mem_max = 0.40 if os.name == 'nt' else 0.5
140
+ return perc_reserved_mem_max * physical_memory
141
+
142
+ def _detect_main_towers(model, verboseLevel=1):
143
+ cur_blocks_prefix = None
144
+ towers_modules= []
145
+ towers_names= []
146
+
147
+ for submodule_name, submodule in model.named_modules():
148
+ if submodule_name=='':
149
+ continue
150
+
151
+ if isinstance(submodule, torch.nn.ModuleList):
152
+ newList =False
153
+ if cur_blocks_prefix == None:
154
+ cur_blocks_prefix = submodule_name + "."
155
+ newList = True
156
+ else:
157
+ if not submodule_name.startswith(cur_blocks_prefix):
158
+ cur_blocks_prefix = submodule_name + "."
159
+ newList = True
160
+
161
+ if newList and len(submodule)>=5:
162
+ towers_names.append(submodule_name)
163
+ towers_modules.append(submodule)
164
+
165
+ else:
166
+ if cur_blocks_prefix is not None:
167
+ if not submodule_name.startswith(cur_blocks_prefix):
168
+ cur_blocks_prefix = None
169
+
170
+ return towers_names, towers_modules
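As a rough sketch of what this helper treats as a "tower" (a torch.nn.ModuleList of at least 5 blocks, typically the stack of transformer blocks), probed on a toy model with arbitrary layer sizes:

    import torch
    from mmgp.offload import _detect_main_towers

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Linear(8, 8)   # plain submodule: ignored
            self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(6)])  # detected as a tower

    names, modules = _detect_main_towers(ToyModel())
    print(names)   # ['blocks']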
171
+
172
+
173
+
174
+ def _get_model(model_path):
175
+ if os.path.isfile(model_path):
176
+ return model_path
177
+
178
+ from pathlib import Path
179
+ _path = Path(model_path).parts
180
+ _filename = _path[-1]
181
+ _path = _path[:-1]
182
+ if len(_path)==1:
183
+ raise Exception("file not found")
184
+ else:
185
+ from huggingface_hub import hf_hub_download #snapshot_download,
186
+ repoId= os.path.join(*_path[0:2] ).replace("\\", "/")
187
+
188
+ if len(_path) > 2:
189
+ _subfolder = os.path.join(*_path[2:] )
190
+ model_path = hf_hub_download(repo_id=repoId, filename=_filename, subfolder=_subfolder)
191
+ else:
192
+ model_path = hf_hub_download(repo_id=repoId, filename=_filename)
193
+
194
+ return model_path
195
+
196
+
197
+
198
+ def _remove_model_wrapper(model):
199
+ if not model._modules is None:
200
+ if len(model._modules)!=1:
201
+ return model
202
+ sub_module = model._modules[next(iter(model._modules))]
203
+ if hasattr(sub_module,"config") or hasattr(sub_module,"base_model"):
204
+ return sub_module
205
+ return model
206
+
207
+
208
+
209
+ def _move_to_pinned_tensor(source_tensor, big_tensor, offset, length):
210
+ dtype= source_tensor.dtype
211
+ shape = source_tensor.shape
212
+ if len(shape) == 0:
213
+ return source_tensor
214
+ else:
215
+ t = source_tensor.view(torch.uint8)
216
+ t = torch.reshape(t, (length,))
217
+ # magic swap !
218
+ big_tensor[offset: offset + length] = t
219
+ t = big_tensor[offset: offset + length]
220
+ t = t.view(dtype)
221
+ t = torch.reshape(t, shape)
222
+ assert t.is_pinned()
223
+ return t
224
+
225
+ def _safetensors_load_file(file_path):
226
+ from collections import OrderedDict
227
+ sd = OrderedDict()
228
+
229
+ with safetensors2.safe_open(file_path, framework="pt", device="cpu") as f:
230
+ for k in f.keys():
231
+ sd[k] = f.get_tensor(k)
232
+ metadata = f.metadata()
233
+
234
+ return sd, metadata
235
+
236
+ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
237
+ if verboseLevel>=1 :
238
+ if partialPinning:
239
+ print(f"Partial pinning of data of '{model_id}' to reserved RAM")
240
+ else:
241
+ print(f"Pinning data of '{model_id}' to reserved RAM")
242
+
243
+ max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
244
+ if partialPinning:
245
+ towers_names, _ = _detect_main_towers(model)
246
+ towers_names = [n +"." for n in towers_names]
247
+
248
+ BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
249
+ current_big_tensor_size = 0
250
+ big_tensor_no = 0
251
+ big_tensors_sizes = []
252
+ tensor_map_indexes = []
253
+ total_tensor_bytes = 0
254
+
255
+ params_list = []
256
+ for k, sub_module in model.named_modules():
257
+ include = True
258
+ if partialPinning:
259
+ include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
260
+ if include:
261
+ params_list = params_list + list(sub_module.buffers(recurse=False)) + list(sub_module.parameters(recurse=False))
262
+
263
+ # print(f"num params to pin {model_id}: {len(params_list)}")
264
+ for p in params_list:
265
+ if isinstance(p, QTensor):
266
+ length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
267
+ else:
268
+ length = torch.numel(p.data) * p.data.element_size()
269
+
270
+ if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
271
+ big_tensors_sizes.append(current_big_tensor_size)
272
+ current_big_tensor_size = 0
273
+ big_tensor_no += 1
274
+ tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
275
+ current_big_tensor_size += length
276
+
277
+ total_tensor_bytes += length
278
+
279
+
280
+ big_tensors_sizes.append(current_big_tensor_size)
281
+
282
+ big_tensors = []
283
+ last_big_tensor = 0
284
+ total = 0
285
+
286
+
287
+
288
+ for size in big_tensors_sizes:
289
+ try:
290
+ current_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True, device="cpu")
291
+ big_tensors.append(current_big_tensor)
292
+ except:
293
+ print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
294
+ break
295
+
296
+ last_big_tensor += 1
297
+ total += size
298
+
299
+
300
+ gc.collect()
301
+
302
+ tensor_no = 0
303
+ for p in params_list:
304
+ big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
305
+
306
+ if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
307
+ current_big_tensor = big_tensors[big_tensor_no]
308
+ if isinstance(p, QTensor):
309
+ length1 = torch.numel(p._data) * p._data.element_size()
310
+ p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
311
+ length2 = torch.numel(p._scale) * p._scale.element_size()
312
+ p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
313
+ else:
314
+ length = torch.numel(p.data) * p.data.element_size()
315
+ p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
316
+
317
+ tensor_no += 1
318
+ global total_pinned_bytes
319
+ total_pinned_bytes += total
320
+ gc.collect()
321
+
322
+ if verboseLevel >=1:
323
+ if total_tensor_bytes == total:
324
+ print(f"The whole model was pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
325
+ else:
326
+ print(f"{total/ONE_MB:.2f} MB were pinned to reserved RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
327
+
328
+ model._already_pinned = True
329
+
330
+
331
+ return
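Why pinning pays off, in a nutshell: copies from page-locked (pinned) host memory can be issued asynchronously. A toy sketch, independent of the model-specific bookkeeping above, using a 256 MB staging block that mirrors BIG_TENSOR_MAX_SIZE:

    import torch

    staging = torch.empty(2**28, dtype=torch.uint8, pin_memory=True, device="cpu")
    gpu_copy = staging.cuda(non_blocking=True)  # returns immediately, the copy overlaps with host work
    torch.cuda.synchronize()                    # wait for the copy only when the data is actually needed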
332
+ welcome_displayed = False
333
+
334
+ def _welcome():
335
+ global welcome_displayed
336
+ if welcome_displayed:
337
+ return
338
+ welcome_displayed = True
339
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
340
+
341
+
342
+ # def _pin_to_memory_sd(model, sd, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
343
+ # if verboseLevel>=1 :
344
+ # if partialPinning:
345
+ # print(f"Partial pinning to reserved RAM of data of file '{model_id}' while loading it")
346
+ # else:
347
+ # print(f"Pinning data to reserved RAM of file '{model_id}' while loading it")
348
+
349
+ # max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
350
+ # if partialPinning:
351
+ # towers_names, _ = _detect_main_towers(model)
352
+ # towers_names = [n +"." for n in towers_names]
353
+
354
+ # BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
355
+ # current_big_tensor_size = 0
356
+ # big_tensor_no = 0
357
+ # big_tensors_sizes = []
358
+ # tensor_map_indexes = []
359
+ # total_tensor_bytes = 0
360
+
361
+ # for k,t in sd.items():
362
+ # include = True
363
+ # # if isinstance(p, QTensor):
364
+ # # length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
365
+ # # else:
366
+ # # length = torch.numel(p.data) * p.data.element_size()
367
+ # length = torch.numel(t) * t.data.element_size()
368
+
369
+ # if partialPinning:
370
+ # include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
371
+
372
+ # if include:
373
+ # if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
374
+ # big_tensors_sizes.append(current_big_tensor_size)
375
+ # current_big_tensor_size = 0
376
+ # big_tensor_no += 1
377
+ # tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
378
+ # current_big_tensor_size += length
379
+ # else:
380
+ # tensor_map_indexes.append((-1, 0, 0 ))
381
+ # total_tensor_bytes += length
382
+
383
+ # big_tensors_sizes.append(current_big_tensor_size)
384
+
385
+ # big_tensors = []
386
+ # last_big_tensor = 0
387
+ # total = 0
388
+
389
+
390
+ # for size in big_tensors_sizes:
391
+ # try:
392
+ # currrent_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True)
393
+ # big_tensors.append(currrent_big_tensor)
394
+ # except:
395
+ # print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
396
+ # break
397
+
398
+ # last_big_tensor += 1
399
+ # total += size
400
+
401
+
402
+ # tensor_no = 0
403
+ # for k,t in sd.items():
404
+ # big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
405
+ # if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
406
+ # current_big_tensor = big_tensors[big_tensor_no]
407
+ # # if isinstance(p, QTensor):
408
+ # # length1 = torch.numel(p._data) * p._data.element_size()
409
+ # # p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
410
+ # # length2 = torch.numel(p._scale) * p._scale.element_size()
411
+ # # p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
412
+ # # else:
413
+ # # length = torch.numel(p.data) * p.data.element_size()
414
+ # # p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
415
+ # length = torch.numel(t) * t.data.element_size()
416
+ # t = _move_to_pinned_tensor(t, current_big_tensor, offset, length)
417
+ # sd[k] = t
418
+ # tensor_no += 1
419
+
420
+ # global total_pinned_bytes
421
+ # total_pinned_bytes += total
422
+
423
+ # if verboseLevel >=1:
424
+ # if total_tensor_bytes == total:
425
+ # print(f"The whole model was pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
426
+ # else:
427
+ # print(f"{total/ONE_MB:.2f} MB were pinned to reserved RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
428
+
429
+ # model._already_pinned = True
430
+
431
+
432
+ # return
433
+
434
+ def _quantize_dirty_hack(model):
435
+ # dirty hack: add a hook on state_dict() to return a fake non quantized state_dict if called by Lora Diffusers initialization functions
436
+ setattr( model, "_real_state_dict", model.state_dict)
437
+ from collections import OrderedDict
438
+ import traceback
439
+
440
+ def state_dict_for_lora(self):
441
+ real_sd = self._real_state_dict()
442
+ fakeit = False
443
+ stack = traceback.extract_stack(f=None, limit=5)
444
+ for frame in stack:
445
+ if "_lora_" in frame.name:
446
+ fakeit = True
447
+ break
448
+
449
+ if not fakeit:
450
+ return real_sd
451
+ sd = OrderedDict()
452
+ for k in real_sd:
453
+ v = real_sd[k]
454
+ if k.endswith("._data"):
455
+ k = k[:len(k)-6]
456
+ sd[k] = v
457
+ return sd
458
+
459
+ setattr(model, "state_dict", functools.update_wrapper(functools.partial(state_dict_for_lora, model), model.state_dict) )
460
+
461
+ def _quantization_map(model):
462
+ from optimum.quanto import quantization_map
463
+ return quantization_map(model)
464
+
465
+ def _set_module_by_name(parent_module, name, child_module):
466
+ module_names = name.split(".")
467
+ if len(module_names) == 1:
468
+ setattr(parent_module, name, child_module)
469
+ else:
470
+ parent_module_name = name[: name.rindex(".")]
471
+ parent_module = parent_module.get_submodule(parent_module_name)
472
+ setattr(parent_module, module_names[-1], child_module)
473
+
474
+ def _quantize_submodule(
475
+ model: torch.nn.Module,
476
+ name: str,
477
+ module: torch.nn.Module,
478
+ weights = None,
479
+ activations = None,
480
+ optimizer = None,
481
+ ):
482
+
483
+ qmodule = quantize_module(module, weights=weights, activations=activations, optimizer=optimizer)
484
+ if qmodule is not None:
485
+ _set_module_by_name(model, name, qmodule)
486
+ qmodule.name = name
487
+ for name, param in module.named_parameters():
488
+ # Save device memory by clearing parameters
489
+ setattr(module, name, None)
490
+ del param
491
+
492
+ def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict):
493
+ # change dtype of current meta model parameters because 'requantize' won't update the dtype on non quantized parameters
494
+ for k, p in model.named_parameters():
495
+ if not k in quantization_map and k in state_dict:
496
+ p_in_file = state_dict[k]
497
+ if p.data.dtype != p_in_file.data.dtype:
498
+ p.data = p.data.to(p_in_file.data.dtype)
499
+
500
+ # rebuild quanto objects
501
+ for name, m in model.named_modules():
502
+ qconfig = quantization_map.get(name, None)
503
+ if qconfig is not None:
504
+ weights = qconfig["weights"]
505
+ if weights == "none":
506
+ weights = None
507
+ activations = qconfig["activations"]
508
+ if activations == "none":
509
+ activations = None
510
+ _quantize_submodule(model, name, m, weights=weights, activations=activations)
511
+
512
+ model._quanto_map = quantization_map
513
+
514
+ _quantize_dirty_hack(model)
515
+
516
+
517
+
518
+ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 1000000000, model_id = 'Unknown'):
519
+
520
+ def compute_submodule_size(submodule):
521
+ size = 0
522
+ for p in submodule.parameters(recurse=False):
523
+ size += torch.numel(p.data) * sizeofbfloat16
524
+
525
+ for p in submodule.buffers(recurse=False):
526
+ size += torch.numel(p.data) * sizeofbfloat16
527
+
528
+ return size
529
+
530
+ total_size =0
531
+ total_excluded = 0
532
+ exclude_list = []
533
+ submodule_size = 0
534
+ submodule_names = []
535
+ cur_blocks_prefix = None
536
+ prev_blocks_prefix = None
537
+
538
+ if hasattr(model_to_quantize, "_quanto_map"):
539
+ print(f"Model '{model_id}' is already quantized")
540
+ return False
541
+
542
+ print(f"Quantization of model '{model_id}' started")
543
+
544
+ for submodule_name, submodule in model_to_quantize.named_modules():
545
+ if isinstance(submodule, QModuleMixin):
546
+ if verboseLevel>=1:
547
+ print("No quantization to do as model is already quantized")
548
+ return False
549
+
550
+
551
+ if submodule_name=='':
552
+ continue
553
+
554
+
555
+ flush = False
556
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
557
+ if cur_blocks_prefix == None:
558
+ cur_blocks_prefix = submodule_name + "."
559
+ flush = True
560
+ else:
561
+ #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
562
+ if not submodule_name.startswith(cur_blocks_prefix):
563
+ cur_blocks_prefix = submodule_name + "."
564
+ flush = True
565
+ else:
566
+ if cur_blocks_prefix is not None:
567
+ #if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
568
+ if not submodule_name.startswith(cur_blocks_prefix):
569
+ cur_blocks_prefix = None
570
+ flush = True
571
+
572
+ if flush:
573
+ if submodule_size <= threshold:
574
+ exclude_list += submodule_names
575
+ if verboseLevel >=2:
576
+ print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
577
+ total_excluded += submodule_size
578
+
579
+ submodule_size = 0
580
+ submodule_names = []
581
+ prev_blocks_prefix = cur_blocks_prefix
582
+ size = compute_submodule_size(submodule)
583
+ submodule_size += size
584
+ total_size += size
585
+ submodule_names.append(submodule_name)
586
+
587
+ if submodule_size > 0 and submodule_size <= threshold:
588
+ exclude_list += submodule_names
589
+ if verboseLevel >=2:
590
+ print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
591
+ total_excluded += submodule_size
592
+
593
+ perc_excluded =total_excluded/ total_size if total_size >0 else 1
594
+ if verboseLevel >=2:
595
+ print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB oF {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
596
+ if perc_excluded >= 0.10:
597
+ print(f"Too many many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
598
+ exclude_list = None
599
+
600
+
601
+ #quantize(model_to_quantize,weights, exclude= exclude_list)
602
+ pass
603
+ for name, m in model_to_quantize.named_modules():
604
+ if exclude_list is None or not any( name == module_name for module_name in exclude_list):
605
+ _quantize_submodule(model_to_quantize, name, m, weights=weights, activations=None, optimizer=None)
606
+
607
+ # force read non quantized parameters so that their lazy tensors and corresponding mmap are released
608
+ # otherwise we may end up keeping in memory both the quantized and the non-quantized model
609
+ for m in model_to_quantize.modules():
610
+ # do not read quantized weights (detected directly or behind an adapter)
611
+ if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
612
+ if hasattr(m, "bias") and m.bias is not None:
613
+ m.bias.data = m.bias.data + 0
614
+ else:
615
+ for n, p in m.named_parameters(recurse = False):
616
+ data = getattr(m, n)
617
+ setattr(m,n, torch.nn.Parameter(data + 0 ) )
618
+
619
+ for b in m.buffers(recurse = False):
620
+ b.data = b.data + 0
621
+
622
+
623
+ freeze(model_to_quantize)
624
+ torch.cuda.empty_cache()
625
+ gc.collect()
626
+ quantization_map = _quantization_map(model_to_quantize)
627
+ model_to_quantize._quanto_map = quantization_map
628
+
629
+ _quantize_dirty_hack(model_to_quantize)
630
+
631
+ print(f"Quantization of model '{model_id}' done")
632
+
633
+ return True
634
+
635
+ def get_model_name(model):
636
+ return model.name
637
+
638
+ class HfHook:
639
+ def __init__(self):
640
+ self.execution_device = "cuda"
641
+
642
+ def detach_hook(self, module):
643
+ pass
644
+
645
+ last_offload_obj = None
646
+ class offload:
647
+ def __init__(self):
648
+ self.active_models = []
649
+ self.active_models_ids = []
650
+ self.active_subcaches = {}
651
+ self.models = {}
652
+ self.verboseLevel = 0
653
+ self.blocks_of_modules = {}
654
+ self.blocks_of_modules_sizes = {}
655
+ self.anyCompiledModule = False
656
+ self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
657
+ self.last_reserved_mem_check =0
658
+ self.loaded_blocks = {}
659
+ self.prev_blocks_names = {}
660
+ self.next_blocks_names = {}
661
+ self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
662
+ self.transfer_stream = torch.cuda.Stream()
663
+ self.async_transfers = False
664
+ global last_offload_obj
665
+ last_offload_obj = self
666
+
667
+ def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
668
+
669
+ entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
670
+ if entry_name in self.blocks_of_modules:
671
+ blocks_params = self.blocks_of_modules[entry_name]
672
+ blocks_params_size = self.blocks_of_modules_sizes[entry_name]
673
+ else:
674
+ blocks_params = []
675
+ self.blocks_of_modules[entry_name] = blocks_params
676
+ blocks_params_size = 0
677
+ if blocks_name !=None:
678
+
679
+ prev_entry_name = None if prev_block_name == None else model_id + "/" + prev_block_name
680
+ self.prev_blocks_names[entry_name] = prev_entry_name
681
+ if not prev_block_name == None:
682
+ self.next_blocks_names[prev_entry_name] = entry_name
683
+
684
+
685
+ for k,p in submodule.named_parameters(recurse=False):
686
+ if isinstance(p, QTensor):
687
+ blocks_params.append( (submodule, k, p._data, p._scale) )
688
+ blocks_params_size += p._data.nbytes
689
+ blocks_params_size += p._scale.nbytes
690
+ else:
691
+ blocks_params.append( (submodule, k, p.data, None) )
692
+ blocks_params_size += p.data.nbytes
693
+
694
+ for k, p in submodule.named_buffers(recurse=False):
695
+ blocks_params.append( (submodule, k, p.data, None) )
696
+ blocks_params_size += p.data.nbytes
697
+
698
+
699
+ self.blocks_of_modules_sizes[entry_name] = blocks_params_size
700
+
701
+ return blocks_params_size
702
+
703
+
704
+ def can_model_be_cotenant(self, model_id):
705
+ potential_cotenants= cotenants_map.get(model_id, None)
706
+ if potential_cotenants is None:
707
+ return False
708
+ for existing_cotenant in self.active_models_ids:
709
+ if existing_cotenant not in potential_cotenants:
710
+ return False
711
+ return True
712
+
713
+ def gpu_load_blocks(self, model_id, blocks_name, async_load = False):
714
+ # cl = clock.start()
715
+
716
+ if blocks_name != None:
717
+ self.loaded_blocks[model_id] = blocks_name
718
+
719
+ entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
720
+
721
+ def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream = None):
722
+ with torch.cuda.stream(stream_to_use):
723
+ for param in blocks_params:
724
+ parent_module, n, data, scale = param
725
+ p = getattr(parent_module, n)
726
+ if isinstance(p, QTensor):
727
+ q = WeightQBytesTensor.create(p.qtype, p.axis, p.size(), p.stride(), data.cuda(non_blocking=True), scale.cuda(non_blocking=True), activation_qtype=p.activation_qtype, requires_grad=p.requires_grad )
728
+ #q = p.to("cuda", non_blocking=True)
729
+ q = torch.nn.Parameter(q , requires_grad=False)
730
+ setattr(parent_module, n , q)
731
+ del p
732
+ else:
733
+ p.data = p.data.cuda(non_blocking=True)
734
+
735
+ if record_for_stream != None:
736
+ if isinstance(p, QTensor):
737
+ q._data.record_stream(record_for_stream)
738
+ q._scale.record_stream(record_for_stream)
739
+ else:
740
+ p.data.record_stream(record_for_stream)
741
+
742
+
743
+ if self.verboseLevel >=2:
744
+ model = self.models[model_id]
745
+ model_name = model._get_name()
746
+ print(f"Loading model {entry_name} ({model_name}) in GPU")
747
+
748
+
749
+ if self.async_transfers and blocks_name != None:
750
+ first = self.prev_blocks_names[entry_name] == None
751
+ next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
752
+ if first:
753
+ cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
754
+ torch.cuda.synchronize()
755
+
756
+ if next_blocks_entry != None:
757
+ cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry]) #, self.default_stream
758
+
759
+ else:
760
+ cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
761
+ torch.cuda.synchronize()
762
+ # cl.stop()
763
+ # print(f"load time: {cl.format_time_gap()}")
764
+
765
+
766
+ def gpu_unload_blocks(self, model_id, blocks_name):
767
+ # cl = clock.start()
768
+ if blocks_name != None:
769
+ self.loaded_blocks[model_id] = None
770
+
771
+ blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
772
+
773
+ if self.verboseLevel >=2:
774
+ model = self.models[model_id]
775
+ model_name = model._get_name()
776
+ print(f"Unloading model {blocks_name} ({model_name}) from GPU")
777
+
778
+ blocks_params = self.blocks_of_modules[blocks_name]
779
+
780
+ for param in blocks_params:
781
+ parent_module, n, data, scale = param
782
+ p = getattr(parent_module, n)
783
+ if isinstance(p, QTensor):
784
+ # need to change the parameter directly from the module as it can't be swapped in place due to a memory leak in the pytorch compiler
785
+ q = WeightQBytesTensor.create(p.qtype, p.axis, p.size(), p.stride(), data, scale, activation_qtype=p.activation_qtype, requires_grad=p.requires_grad )
786
+ q = torch.nn.Parameter(q , requires_grad=False)
787
+ setattr(parent_module, n , q)
788
+ del p
789
+ else:
790
+ p.data = data
791
+
792
+ # cl.stop()
793
+ # print(f"unload time: {cl.format_time_gap()}")
794
+
795
+
796
+ def gpu_load(self, model_id):
797
+ model = self.models[model_id]
798
+ self.active_models.append(model)
799
+ self.active_models_ids.append(model_id)
800
+
801
+ self.gpu_load_blocks(model_id, None)
802
+
803
+ # torch.cuda.current_stream().synchronize()
804
+
805
+ def unload_all(self):
806
+ for model_id in self.active_models_ids:
807
+ self.gpu_unload_blocks(model_id, None)
808
+ loaded_block = self.loaded_blocks[model_id]
809
+ if loaded_block != None:
810
+ self.gpu_unload_blocks(model_id, loaded_block)
811
+ self.loaded_blocks[model_id] = None
812
+
813
+ self.active_models = []
814
+ self.active_models_ids = []
815
+ self.active_subcaches = []
816
+ torch.cuda.empty_cache()
817
+ gc.collect()
818
+ self.last_reserved_mem_check = time.time()
819
+
820
+ def move_args_to_gpu(self, *args, **kwargs):
821
+ new_args= []
822
+ new_kwargs={}
823
+ for arg in args:
824
+ if torch.is_tensor(arg):
825
+ if arg.dtype == torch.float32:
826
+ arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
827
+ else:
828
+ arg = arg.cuda(non_blocking=True)
829
+ new_args.append(arg)
830
+
831
+ for k in kwargs:
832
+ arg = kwargs[k]
833
+ if torch.is_tensor(arg):
834
+ if arg.dtype == torch.float32:
835
+ arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
836
+ else:
837
+ arg = arg.cuda(non_blocking=True)
838
+ new_kwargs[k]= arg
839
+
840
+ return new_args, new_kwargs
841
+
842
+ def ready_to_check_mem(self):
843
+ if self.anyCompiledModule:
844
+ return
845
+ cur_clock = time.time()
846
+ # can't check at each call whether we can empty the cuda cache as querying the reserved memory value is a time-consuming operation
847
+ if (cur_clock - self.last_reserved_mem_check)<0.200:
848
+ return False
849
+ self.last_reserved_mem_check = cur_clock
850
+ return True
851
+
852
+
853
+ def empty_cache_if_needed(self):
854
+ mem_reserved = torch.cuda.memory_reserved()
855
+ mem_threshold = 0.9*self.device_mem_capacity
856
+ if mem_reserved >= mem_threshold:
857
+ mem_allocated = torch.cuda.memory_allocated()
858
+ if mem_allocated <= 0.70 * mem_reserved:
859
+ # print(f"Cuda empty cache triggered as Allocated Memory ({mem_allocated/1024000:0f} MB) is lot less than Cached Memory ({mem_reserved/1024000:0f} MB) ")
860
+ torch.cuda.empty_cache()
861
+ tm= time.time()
862
+ if self.verboseLevel >=2:
863
+ print(f"Empty Cuda cache at {tm}")
864
+ # print(f"New cached memory after purge is {torch.cuda.memory_reserved()/1024000:0f} MB) ")
865
+
866
+
867
+ def any_param_or_buffer(self, target_module: torch.nn.Module):
868
+
869
+ for _ in target_module.parameters(recurse= False):
870
+ return True
871
+
872
+ for _ in target_module.buffers(recurse= False):
873
+ return True
874
+
875
+ return False
876
+
877
+ def hook_load_data_if_needed(self, target_module, model_id,blocks_name, context):
878
+
879
+ @torch.compiler.disable()
880
+ def load_data_if_needed(module, *args, **kwargs):
881
+ some_context = context #for debugging
882
+ if blocks_name == None:
883
+ if self.ready_to_check_mem():
884
+ self.empty_cache_if_needed()
885
+ else:
886
+ loaded_block = self.loaded_blocks[model_id]
887
+ if (loaded_block == None or loaded_block != blocks_name) :
888
+ if loaded_block != None:
889
+ self.gpu_unload_blocks(model_id, loaded_block)
890
+ if self.ready_to_check_mem():
891
+ self.empty_cache_if_needed()
892
+ self.loaded_blocks[model_id] = blocks_name
893
+ self.gpu_load_blocks(model_id, blocks_name)
894
+
895
+ target_module.register_forward_pre_hook(load_data_if_needed)
896
+
897
+
898
+ def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
899
+
900
+ def check_empty_cuda_cache(module, *args, **kwargs):
901
+ # if self.ready_to_check_mem():
902
+ # self.empty_cache_if_needed()
903
+ if blocks_name == None:
904
+ if self.ready_to_check_mem():
905
+ self.empty_cache_if_needed()
906
+ else:
907
+ loaded_block = self.loaded_blocks[model_id]
908
+ if (loaded_block == None or loaded_block != blocks_name) :
909
+ if loaded_block != None:
910
+ self.gpu_unload_blocks(model_id, loaded_block)
911
+ if self.ready_to_check_mem():
912
+ self.empty_cache_if_needed()
913
+ self.loaded_blocks[model_id] = blocks_name
914
+ self.gpu_load_blocks(model_id, blocks_name)
915
+
916
+ return previous_method(*args, **kwargs)
917
+
918
+
919
+ if hasattr(target_module, "_mm_id"):
920
+ orig_model_id = getattr(target_module, "_mm_id")
921
+ if self.verboseLevel >=2:
922
+ print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module '{orig_model_id}' ")
923
+ assert not self.any_param_or_buffer(target_module)
924
+
925
+ return
926
+ setattr(target_module, "_mm_id", model_id)
927
+ setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
928
+
929
+
930
+ def hook_change_module(self, target_module, model, model_id, module_id, previous_method):
931
+ def check_change_module(module, *args, **kwargs):
932
+ performEmptyCacheTest = False
933
+ if not model_id in self.active_models_ids:
934
+ new_model_id = getattr(module, "_mm_id")
935
+ # do not always unload existing models if it is more efficient to keep them in the GPU
936
+ # (e.g. small models such as text encoders)
937
+ if not self.can_model_be_cotenant(new_model_id) :
938
+ self.unload_all()
939
+ performEmptyCacheTest = False
940
+ self.gpu_load(new_model_id)
941
+ # transfer leftover inputs that were incorrectly created in RAM (mostly due to some .device tests that incorrectly returned "cpu")
942
+ args, kwargs = self.move_args_to_gpu(*args, **kwargs)
943
+ if performEmptyCacheTest:
944
+ self.empty_cache_if_needed()
945
+
946
+ return previous_method(*args, **kwargs)
947
+
948
+ if hasattr(target_module, "_mm_id"):
949
+ return
950
+ setattr(target_module, "_mm_id", model_id)
951
+
952
+ setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method) )
953
+
954
+ if not self.verboseLevel >=1:
955
+ return
956
+
957
+ if module_id == None or module_id =='':
958
+ model_name = model._get_name()
959
+ print(f"Hooked in model '{model_id}' ({model_name})")
960
+
961
+
962
+ # Not implemented yet, but why would one want to get rid of these features ?
963
+ # def unhook_module(module: torch.nn.Module):
964
+ # if not hasattr(module,"_mm_id"):
965
+ # return
966
+
967
+ # delattr(module, "_mm_id")
968
+
969
+ # def unhook_all(parent_module: torch.nn.Module):
970
+ # for module in parent_module.components.items():
971
+ # self.unhook_module(module)
972
+
973
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
974
+ """
975
+ quick version of .from_pretrained() of the transformers library
976
+ used to build a model and load the corresponding weights (quantized or not)
977
+ """
978
+
979
+
980
+ import os.path
981
+ from accelerate import init_empty_weights
982
+
983
+ if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
984
+ raise Exception("full model path to file expected")
985
+
986
+ model_path = _get_model(model_path)
987
+ verboseLevel = _compute_verbose_level(verboseLevel)
988
+
989
+ with safetensors2.safe_open(model_path) as f:
990
+ metadata = f.metadata()
991
+
992
+ if metadata is None:
993
+ transformer_config = None
994
+ else:
995
+ transformer_config = metadata.get("config", None)
996
+
997
+ if transformer_config == None:
998
+ config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
999
+
1000
+ if not os.path.isfile(config_fullpath):
1001
+ raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
1002
+
1003
+ with open(config_fullpath, "r", encoding="utf-8") as reader:
1004
+ text = reader.read()
1005
+ transformer_config= json.loads(text)
1006
+
1007
+
1008
+ if "architectures" in transformer_config:
1009
+ architectures = transformer_config["architectures"]
1010
+ class_name = architectures[0]
1011
+
1012
+ module = __import__("transformers")
1013
+ transfomer_class = getattr(module, class_name)
1014
+ from transformers import AutoConfig
1015
+
1016
+ import tempfile
1017
+ with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
1018
+ fp.write(json.dumps(transformer_config))
1019
+ fp.close()
1020
+ config_obj = AutoConfig.from_pretrained(fp.name)
1021
+ os.remove(fp.name)
1022
+
1023
+ # needed to keep inits of non-persistent buffers
1024
+ with init_empty_weights():
1025
+ model = transfomer_class(config_obj)
1026
+
1027
+ model = model.base_model
1028
+
1029
+ elif "_class_name" in transformer_config:
1030
+ class_name = transformer_config["_class_name"]
1031
+
1032
+ module = __import__("diffusers")
1033
+ transfomer_class = getattr(module, class_name)
1034
+
1035
+ with init_empty_weights():
1036
+ model = transfomer_class.from_config(transformer_config)
1037
+
1038
+
1039
+ torch.set_default_device('cpu')
1040
+
1041
+ model._config = transformer_config
1042
+
1043
+ load_model_data(model,model_path, do_quantize = do_quantize, quantization_type = quantization_type, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
1044
+
1045
+ return model
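A possible usage sketch; the file path is hypothetical, and the safetensors file is expected to embed its config as metadata or to sit next to a config.json:

    from mmgp import offload

    text_encoder = offload.fast_load_transformers_model(
        "path/to/text_encoder_quanto_int8.safetensors",  # hypothetical prequantized checkpoint
        do_quantize=False,      # keep False when the file is already quantized
        pinToMemory=True,       # optionally pin the loaded weights to reserved RAM
    )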
1046
+
1047
+
1048
+
1049
+ def load_model_data(model, file_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
1050
+ """
1051
+ Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
1052
+ """
1053
+
1054
+ file_path = _get_model(file_path)
1055
+ verboseLevel = _compute_verbose_level(verboseLevel)
1056
+
1057
+ model = _remove_model_wrapper(model)
1058
+
1059
+ # if pinToMemory and do_quantize:
1060
+ # raise Exception("Pinning and Quantization can not be used at the same time")
1061
+
1062
+ if not (".safetensors" in file_path or ".sft" in file_path):
1063
+ if pinToMemory:
1064
+ raise Exception("Pinning to memory while loading only supported for safe tensors files")
1065
+ state_dict = torch.load(file_path, weights_only=True)
1066
+ if "module" in state_dict:
1067
+ state_dict = state_dict["module"]
1068
+ else:
1069
+ state_dict, metadata = _safetensors_load_file(file_path)
1070
+
1071
+
1072
+ # if pinToMemory:
1073
+ # _pin_to_memory_sd(model,state_dict, file_path, partialPinning = partialPinning, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel = verboseLevel)
1074
+
1075
+ # with safetensors2.safe_open(file_path) as f:
1076
+ # metadata = f.metadata()
1077
+
1078
+
1079
+ if metadata is None:
1080
+ quantization_map = None
1081
+ else:
1082
+ quantization_map = metadata.get("quantization_map", None)
1083
+ config = metadata.get("config", None)
1084
+ if config is not None:
1085
+ model._config = config
1086
+
1087
+
1088
+
1089
+ if quantization_map is None:
1090
+ pos = str.rfind(file_path, ".")
1091
+ if pos > 0:
1092
+ quantization_map_path = file_path[:pos]
1093
+ quantization_map_path += "_map.json"
1094
+
1095
+ if os.path.isfile(quantization_map_path):
1096
+ with open(quantization_map_path, 'r') as f:
1097
+ quantization_map = json.load(f)
1098
+
1099
+
1100
+
1101
+ if quantization_map is None :
1102
+ if "quanto" in file_path and not do_quantize:
1103
+ print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
1104
+ else:
1105
+ _requantize(model, state_dict, quantization_map)
1106
+
1107
+ missing_keys , unexpected_keys = model.load_state_dict(state_dict, strict = quantization_map is None, assign = True )
1108
+ del state_dict
1109
+
1110
+ if do_quantize:
1111
+ if quantization_map is None:
1112
+ if _quantize(model, quantization_type, verboseLevel=verboseLevel, model_id=file_path):
1113
+ quantization_map = model._quanto_map
1114
+ else:
1115
+ if verboseLevel >=1:
1116
+ print("Model already quantized")
1117
+
1118
+ if pinToMemory:
1119
+ _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
1120
+
1121
+ return
1122
+
1123
+ def save_model(model, file_path, do_quantize = False, quantization_type = qint8, verboseLevel = -1 ):
1124
+ """save the weights of a model and quantize them if requested
1125
+ These weights can be loaded again using 'load_model_data'
1126
+ """
1127
+
1128
+ config = None
1129
+
1130
+ verboseLevel = _compute_verbose_level(verboseLevel)
1131
+
1132
+ if hasattr(model, "_config"):
1133
+ config = model._config
1134
+ elif hasattr(model, "config"):
1135
+ config_fullpath = None
1136
+ config_obj = getattr(model,"config")
1137
+ config_path = getattr(config_obj,"_name_or_path", None)
1138
+ if config_path != None:
1139
+ config_fullpath = os.path.join(config_path, "config.json")
1140
+ if not os.path.isfile(config_fullpath):
1141
+ config_fullpath = None
1142
+ if config_fullpath is None:
1143
+ config_fullpath = os.path.join(os.path.dirname(file_path), "config.json")
1144
+ if os.path.isfile(config_fullpath):
1145
+ with open(config_fullpath, "r", encoding="utf-8") as reader:
1146
+ text = reader.read()
1147
+ config= json.loads(text)
1148
+
1149
+ if do_quantize:
1150
+ _quantize(model, weights=quantization_type, model_id=file_path)
1151
+
1152
+ quantization_map = getattr(model, "_quanto_map", None)
1153
+
1154
+ if verboseLevel >=1:
1155
+ print(f"Saving file '{file_path}")
1156
+ safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
1157
+ if verboseLevel >=1:
1158
+ print(f"File '{file_path} saved")
1159
+
1160
+
1161
+
1162
+
1163
+ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = -1):
1164
+ """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
1165
+ pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
1166
+ quantizeTransformer: True by default, quantizes on the fly the video / image model
1167
+ pinnedMemory: move models to reserved memory. This allows very fast loading / unloading but requires 50% extra RAM (usually >= 64 GB)
1168
+ extraModelsToQuantize: a list of models to be also quantized on the fly (e.g. the text_encoder), useful to reduce both RAM and VRAM consumption
1169
+ budgets: 0 by default (unlimited). If non 0, it corresponds to the maximum size in MB that every model will occupy at any moment
1170
+ (in fact the real usage is twice this number). It is very efficient to reduce VRAM consumption but this feature may be very slow
1171
+ if pinnedMemory is not enabled
1172
+ """
1173
+ self = offload()
1174
+ self.verboseLevel = verboseLevel
1175
+ safetensors2.verboseLevel = verboseLevel
1176
+ self.modules_data = {}
1177
+ model_budgets = {}
1178
+
1179
+ windows_os = os.name == 'nt'
1180
+ global total_pinned_bytes
1181
+
1182
+
1183
+ budget = 0
1184
+ if not budgets is None:
1185
+ if isinstance(budgets , dict):
1186
+ model_budgets = budgets
1187
+ else:
1188
+ budget = int(budgets) * ONE_MB
1189
+
1190
+ # if (budgets!= None or budget >0) :
1191
+ # self.async_transfers = True
1192
+ self.async_transfers = asyncTransfers
1193
+
1194
+
1195
+
1196
+ torch.set_default_device('cpu')
1197
+
1198
+ if hasattr(pipe_or_dict_of_modules, "components"):
1199
+ # create a fake Accelerate parameter so that lora loading doesn't change the device
1200
+ pipe_or_dict_of_modules.hf_device_map = torch.device("cuda")
1201
+ pipe_or_dict_of_modules= pipe_or_dict_of_modules.components
1202
+
1203
+
1204
+ models = {k: _remove_model_wrapper(v) for k, v in pipe_or_dict_of_modules.items() if isinstance(v, torch.nn.Module)}
1205
+
1206
+
1207
+ verboseLevel = _compute_verbose_level(verboseLevel)
1208
+
1209
+ _welcome()
1210
+
1211
+ self.models = models
1212
+
1213
+ extraModelsToQuantize = extraModelsToQuantize if extraModelsToQuantize is not None else []
1214
+ if not isinstance(extraModelsToQuantize, list):
1215
+ extraModelsToQuantize= [extraModelsToQuantize]
1216
+ if quantizeTransformer:
1217
+ extraModelsToQuantize.append("transformer")
1218
+ models_to_quantize = extraModelsToQuantize
1219
+
1220
+ modelsToPin = []
1221
+ pinAllModels = False
1222
+ if isinstance(pinnedMemory, bool):
1223
+ pinAllModels = pinnedMemory
1224
+ elif isinstance(pinnedMemory, list):
1225
+ modelsToPin = pinnedMemory
1226
+ else:
1227
+ modelsToPin = [pinnedMemory]
1228
+
1229
+ modelsToCompile = []
1230
+ compileAllModels = False
1231
+ if isinstance(compile, bool):
1232
+ compileAllModels = compile
1233
+ elif isinstance(compile, list):
1234
+ modelsToCompile = compile
1235
+ else:
1236
+ modelsToCompile = [compile]
1237
+
1238
+ self.anyCompiledModule = compileAllModels or len(modelsToCompile)>0
1239
+ if self.anyCompiledModule:
1240
+ torch._dynamo.config.cache_size_limit = 10000
1241
+ # torch._logging.set_logs(recompiles=True)
1242
+ # torch._inductor.config.realize_opcount_threshold = 100 # workaround bug "AssertionError: increase TRITON_MAX_BLOCK['X'] to 4096."
1243
+
1244
+ max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
1245
+
1246
+ estimatesBytesToPin = 0
1247
+
1248
+ for model_id in models:
1249
+ current_model: torch.nn.Module = models[model_id]
1250
+ # make sure that no RAM or GPU memory is allocated for gradients / training
1251
+ current_model.to("cpu").eval()
1252
+
1253
+ # quantize the model if requested (no effect if it has already been quantized)
1254
+ if model_id in models_to_quantize:
1255
+ _quantize(current_model, weights=qint8, verboseLevel = self.verboseLevel, model_id=model_id)
1256
+
1257
+ modelPinned = (pinAllModels or model_id in modelsToPin) and not hasattr(current_model,"_already_pinned")
1258
+
1259
+ current_model_size = 0
1260
+ # load all the remaining unread lazy safetensors in RAM to free open cache files
1261
+ for p in current_model.parameters():
1262
+ if isinstance(p, QTensor):
1263
+ # # fix quanto bug (seems to have been fixed)
1264
+ # if not modelPinned and p._scale.dtype == torch.float32:
1265
+ # p._scale = p._scale.to(torch.bfloat16)
1266
+ current_model_size += torch.numel(p._scale) * p._scale.element_size()
1267
+ current_model_size += torch.numel(p._data) * p._data.element_size()
1268
+ else:
1269
+ if p.data.dtype == torch.float32:
1270
+ # convert any leftover float32 weights to bfloat16 to halve the model memory footprint
1271
+ p.data = p.data.to(torch.bfloat16)
1272
+ current_model_size += torch.numel(p.data) * p.data.element_size()
1273
+
1274
+ for b in current_model.buffers():
1275
+ if b.data.dtype == torch.float32:
1276
+ # convert any leftover float32 weights to bfloat16 to halve the model memory footprint
1277
+ b.data = b.data.to(torch.bfloat16)
1278
+ current_model_size += torch.numel(b.data) * b.data.element_size()
1279
+
1280
+ if modelPinned:
1281
+ estimatesBytesToPin += current_model_size
1282
+
1283
+
1284
+ model_budget = model_budgets[model_id] * ONE_MB if model_id in model_budgets else budget
1285
+
1286
+ if model_budget > 0 and model_budget > current_model_size:
1287
+ model_budget = 0
1288
+
1289
+ model_budgets[model_id] = model_budget
1290
+
1291
+ partialPinning = False
1292
+
1293
+ if estimatesBytesToPin > 0 and estimatesBytesToPin >= (max_reservable_memory - total_pinned_bytes):
1294
+ if self.verboseLevel >=1:
1295
+ print(f"Switching to partial pinning since full requirements for pinned models is {estimatesBytesToPin/ONE_MB:0.1f} MB while estimated reservable RAM is {max_reservable_memory/ONE_MB:0.1f} MB" )
1296
+ partialPinning = True
1297
+
1298
+ # Hook forward methods of modules
1299
+ for model_id in models:
1300
+ current_model: torch.nn.Module = models[model_id]
1301
+ current_budget = model_budgets[model_id]
1302
+ current_size = 0
1303
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1304
+ self.loaded_blocks[model_id] = None
1305
+ towers_names, towers_modules = _detect_main_towers(current_model)
1306
+ towers_names = [n +"." for n in towers_names]
1307
+ if self.verboseLevel>=2 and len(towers_names)>0:
1308
+ print(f"Potential iterative blocks found in model '{model_id}':{towers_names}")
1309
+ # compile main iterative modules stacks ("towers")
1310
+ if compileAllModels or model_id in modelsToCompile :
1311
+ #torch.compiler.reset()
1312
+ if self.verboseLevel>=1:
1313
+ print(f"Pytorch compilation of model '{model_id}' is scheduled.")
1314
+ for tower in towers_modules:
1315
+ for submodel in tower:
1316
+ submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
1317
+ #dynamic=True,
1318
+
1319
+ if pinAllModels or model_id in modelsToPin:
1320
+ if hasattr(current_model,"_already_pinned"):
1321
+ if self.verboseLevel >=1:
1322
+ print(f"Model '{model_id}' already pinned to reserved memory")
1323
+ else:
1324
+ _pin_to_memory(current_model, model_id, partialPinning= partialPinning, perc_reserved_mem_max=perc_reserved_mem_max, verboseLevel=verboseLevel)
1325
+
1326
+ for submodule_name, submodule in current_model.named_modules():
1327
+ # create a fake 'accelerate' parameter so that the _execution_device property returns always "cuda"
1328
+ # (it is queried in many pipelines even if offloading is not properly implemented)
1329
+ if not hasattr(submodule, "_hf_hook"):
1330
+ setattr(submodule, "_hf_hook", HfHook())
1331
+
1332
+ if submodule_name=='':
1333
+ continue
1334
+ newListItem = False
1335
+ if current_budget > 0:
1336
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): #
1337
+ if cur_blocks_prefix == None:
1338
+ cur_blocks_prefix = submodule_name + "."
1339
+ else:
1340
+ #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
1341
+ if not submodule_name.startswith(cur_blocks_prefix):
1342
+ cur_blocks_prefix = submodule_name + "."
1343
+ cur_blocks_name,cur_blocks_seq = None, -1
1344
+ else:
1345
+
1346
+ if cur_blocks_prefix is not None:
1347
+ if submodule_name.startswith(cur_blocks_prefix):
1348
+ num = int(submodule_name[len(cur_blocks_prefix):].split(".")[0])
1349
+ newListItem= num != cur_blocks_seq
1350
+ if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
1351
+ prev_blocks_name = cur_blocks_name
1352
+ cur_blocks_name = cur_blocks_prefix + str(num)
1353
+ # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
1354
+ cur_blocks_seq = num
1355
+ else:
1356
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1357
+
1358
+ if hasattr(submodule, "forward"):
1359
+ submodule_method = getattr(submodule, "forward")
1360
+ if callable(submodule_method):
1361
+ if len(submodule_name.split("."))==1:
1362
+ self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
1363
+ elif newListItem:
1364
+ self.hook_load_data_if_needed(submodule, model_id, cur_blocks_name, context = submodule_name )
1365
+ else:
1366
+ self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name )
1367
+
1368
+
1369
+ current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
1370
+
1371
+
1372
+
1373
+
1374
+ if self.verboseLevel >=2:
+ for n, b in self.blocks_of_modules_sizes.items():
+ print(f"Size of submodel '{n}': {b/ONE_MB:.1f} MB")
+
+ torch.set_default_device('cuda')
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ return self
+
+
+ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.VerylowRAM_LowVRAM, verboseLevel = -1, **overrideKwargs):
+ """Apply a configuration profile that depends on your hardware:
+ pipe_or_dict_of_modules : the pipeline object or a dictionary of the model's modules
+ profile_no : number of the profile:
+ HighRAM_HighVRAM (=1): at least 48 GB of RAM and 24 GB of VRAM : the fastest, well suited for an RTX 3090 / RTX 4090
+ HighRAM_LowVRAM (=2): at least 48 GB of RAM and 12 GB of VRAM : a bit slower, better suited for an RTX 3070/3080/4070/4080,
+ or for an RTX 3090 / RTX 4090 with large picture batches or long videos
+ LowRAM_HighVRAM (=3): at least 32 GB of RAM and 24 GB of VRAM : medium speed, suited for an RTX 3090 / RTX 4090 with limited RAM
+ LowRAM_LowVRAM (=4): at least 32 GB of RAM and 12 GB of VRAM : if you have little VRAM or generate longer videos
+ VerylowRAM_LowVRAM (=5): at least 24 GB of RAM and 10 GB of VRAM : if you don't have much hardware it won't be fast, but it may work
+ overrideKwargs: every parameter accepted by offload.all can be passed here to override the profile defaults
+ For instance set quantizeTransformer = False to disable the transformer quantization that is enabled by default in every profile
+ """
+
+ _welcome()
+
+ verboseLevel = _compute_verbose_level(verboseLevel)
+
+ modules = pipe_or_dict_of_modules
+
+ if hasattr(modules, "components"):
+ modules = modules.components
+
+ modules = {k: _remove_model_wrapper(v) for k, v in modules.items() if isinstance(v, torch.nn.Module)}
+ module_names = {k: v.__module__.lower() for k, v in modules.items()}
+
+ default_extraModelsToQuantize = []
+ quantizeTransformer = True
+
+ models_to_scan = ("text_encoder", "text_encoder_2")
+ candidates_to_quantize = ("t5", "llama", "llm")
+ for model_id in models_to_scan:
+ name = module_names.get(model_id, "") # some pipes do not have a second text encoder
+ for candidate in candidates_to_quantize:
+ if candidate in name:
+ default_extraModelsToQuantize.append(model_id)
+ break
+
+ # the transformer (video or image generator) budget should be as small as possible so as not to occupy space that could be used for the actual image data
+ # on the other hand the text encoder budget should be quite large (as long as it fits in 10 GB of VRAM) to reduce sequence offloading
+
+ default_budgets = { "transformer": 600, "text_encoder": 3000, "text_encoder_2": 3000 }
+ extraModelsToQuantize = None
+ asyncTransfers = True
+
+ if profile_no == profile_type.HighRAM_HighVRAM:
+ pinnedMemory = True
+ budgets = None
+ info = "You have chosen a profile that requires at least 48 GB of RAM and 24 GB of VRAM. Some VRAM is consumed just to make the model run faster."
+ elif profile_no == profile_type.HighRAM_LowVRAM:
+ pinnedMemory = True
+ budgets = default_budgets
+ info = "You have chosen a profile that requires at least 48 GB of RAM and 12 GB of VRAM. Some RAM is consumed to reduce VRAM consumption."
+ elif profile_no == profile_type.LowRAM_HighVRAM:
+ pinnedMemory = "transformer"
+ extraModelsToQuantize = default_extraModelsToQuantize
+ budgets = None # no budgets are needed with 24 GB of VRAM; leaving it unset would make the kwargs line below reference an unassigned variable
+ info = "You have chosen a medium speed profile that requires at least 32 GB of RAM and 24 GB of VRAM. Some VRAM is consumed just to make the model run faster."
+ elif profile_no == profile_type.LowRAM_LowVRAM:
+ pinnedMemory = "transformer"
+ extraModelsToQuantize = default_extraModelsToQuantize
+ budgets = default_budgets
+ info = "You have chosen a profile that requires at least 32 GB of RAM and 12 GB of VRAM. Some RAM is consumed to reduce VRAM consumption."
+ elif profile_no == profile_type.VerylowRAM_LowVRAM:
+ pinnedMemory = False
+ extraModelsToQuantize = default_extraModelsToQuantize
+ budgets = default_budgets
+ budgets["transformer"] = 400
+ asyncTransfers = False
+ info = "You have chosen the slowest profile that requires at least 24 GB of RAM and 10 GB of VRAM."
+ else:
+ raise Exception("Unknown profile")
+ CrLf = '\r\n'
+ kwargs = { "pinnedMemory": pinnedMemory, "extraModelsToQuantize": extraModelsToQuantize, "budgets": budgets, "asyncTransfers": asyncTransfers, "quantizeTransformer": quantizeTransformer }
+
+ if verboseLevel>=2:
+ info = info + CrLf + f"Profile '{profile_type.tostr(profile_no)}' sets the following options:"
+ for k, v in kwargs.items():
+ if k in overrideKwargs:
+ info = info + CrLf + f"- '{k}': '{v}' overridden with value '{overrideKwargs[k]}'"
+ else:
+ info = info + CrLf + f"- '{k}': '{v}'"
+
+ for k, v in overrideKwargs.items():
+ kwargs[k] = v
+
+ if info:
+ print(info)
+
+ return all(pipe_or_dict_of_modules, verboseLevel = verboseLevel, **kwargs)
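
As the docstring notes, any parameter accepted by offload.all can be passed through to override a profile's defaults. A possible call is sketched below; `pipe` stands for the pipeline or module dictionary built earlier in your app, and the budget value is purely illustrative:

    from mmgp import offload, profile_type

    offload.profile(
        pipe,                            # your existing pipeline / dict of modules
        profile_type.LowRAM_LowVRAM,
        quantizeTransformer=False,       # the docstring's example override
        budgets={"transformer": 1200},   # per-model budget in MB, replacing the profile default
    )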