mmgp 2.0.3__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mmgp might be problematic.
- mmgp/__init__.py +22 -0
- mmgp/offload.py +1472 -0
- mmgp/safetensors2.py +387 -0
- {mmgp-2.0.3.dist-info → mmgp-3.0.0.dist-info}/LICENSE.md +1 -1
- {mmgp-2.0.3.dist-info → mmgp-3.0.0.dist-info}/METADATA +155 -137
- mmgp-3.0.0.dist-info/RECORD +9 -0
- mmgp-2.0.3.dist-info/RECORD +0 -7
- mmgp.py +0 -951
- {mmgp-2.0.3.dist-info → mmgp-3.0.0.dist-info}/WHEEL +0 -0
- {mmgp-2.0.3.dist-info → mmgp-3.0.0.dist-info}/top_level.txt +0 -0
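For context, the usage documented in the header of the new mmgp/offload.py (see the diff below) amounts to creating the pipeline on the CPU and handing it to mmgp. A minimal sketch, assuming the Flux pipeline and the profile name mentioned in that header:

    import torch
    from diffusers import FluxPipeline
    from mmgp import offload, profile_type

    # The pipeline must be loaded explicitly on the CPU device first.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cpu")

    # Quick setup: pick one of the predefined profiles...
    offload.profile(pipe, profile_type.HighRAM_LowVRAM_Fast)

    # ...or configure offloading manually (the 'transformer' model is quantized
    # to 8 bits on the fly by default):
    # offload.all(pipe, pinnedMemory=True, extraModelsToQuantize=["text_encoder_2"])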
mmgp/offload.py
ADDED
@@ -0,0 +1,1472 @@
+# ------------------ Memory Management 3.0 for the GPU Poor by DeepBeepMeep (mmgp) ------------------
+#
+# This module contains multiple optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
+# This is a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
+# times in a pipe (e.g. a VAE).
+#
+# Requirements (for Linux; for Windows systems add 16 GB of RAM):
+# - VRAM: minimum 12 GB, recommended 24 GB (RTX 3090 / RTX 4090)
+# - RAM: minimum 24 GB, recommended 48 - 64 GB
+#
+# It is almost plug and play and just needs to be invoked from the main app just after the model pipeline has been created.
+# Make sure that the pipeline explicitly loads the models on the CPU device,
+# for instance: pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cpu")
+# For a quick setup, you may want to choose between 5 profiles depending on your hardware, for instance:
+#   from mmgp import offload, profile_type
+#   offload.profile(pipe, profile_type.HighRAM_LowVRAM_Fast)
+# Alternatively you may want to set your own parameters, for instance:
+#   from mmgp import offload
+#   offload.all(pipe, pinnedMemory=True, extraModelsToQuantize = ["text_encoder_2"])
+# The 'transformer' model that usually contains the video or image generator is quantized on the fly by default to 8 bits so that it can fit into 24 GB of VRAM.
+# You can prevent the transformer quantization by adding the parameter quantizeTransformer = False.
+# If you want to save time on disk and reduce the loading time, you may want to load directly a prequantized model. In that case you need to set the option quantizeTransformer to False to turn off on the fly quantization.
+# You can specify a list of additional model string ids to quantize (for instance the text_encoder) using the optional argument extraModelsToQuantize. This may be useful if you have less than 48 GB of RAM.
+# Note that there is little advantage on the GPU / VRAM side to quantize text encoders as their inputs are usually quite light.
+# Conversely if you have more than 48 GB of RAM you may want to enable RAM pinning with the option pinnedMemory = True. You will get in return super fast loading / unloading of models
+# (this can save significant time if the same pipeline is run multiple times in a row).
+#
+# Sometimes there isn't an explicit pipe object as each submodel is loaded separately in the main app. If this is the case, you need to create a dictionary that manually maps all the models.
+#
+# For instance:
+# for flux derived models: pipe = { "text_encoder": clip, "text_encoder_2": t5, "transformer": model, "vae": ae }
+# for mochi: pipe = { "text_encoder": self.text_encoder, "transformer": self.dit, "vae": self.decoder }
+#
+# Please note that there should always be one model whose id is 'transformer'. It corresponds to the main image / video model which usually needs to be quantized (this is done on the fly by default when loading the model).
+#
+# Be careful, lots of models use the T5 XXL as a text encoder. However, quite often their corresponding pipeline configurations point at the official Google T5 XXL repository
+# where there is a huge 40 GB model to download and load. It is cumbersome as it is a 32 bits model and contains the decoder part of T5 that is not used.
+# I suggest you use instead one of the 16 bits encoder-only versions available around, for instance:
+# text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=torch.float16)
+#
+# Sometimes just providing the pipe won't be sufficient as you will need to change the content of the core model:
+# - For instance you may need to disable an existing CPU offload logic that already exists (such as manual calls to move tensors between cuda and the cpu).
+# - mmgp tries to fake the device as being "cuda", but sometimes some code won't be fooled and it will create tensors on the cpu device and this may cause some issues.
+#
+# You are free to use my module for non commercial use as long as you give me proper credits. You may contact me on twitter @deepbeepmeep
+#
+# Thanks to
+# ---------
+# Huggingface / accelerate for the hooking examples
+# Huggingface / quanto for their very useful quantizer
+# gau-nernst for his pinning RAM samples
+
+
+#
+
+import torch
+import gc
+import time
+import functools
+import sys
+import os
+import json
+import psutil
+from mmgp import safetensors2
+from mmgp import profile_type
+
+from optimum.quanto import freeze, qfloat8, qint8, quantize, QModuleMixin, QTensor, WeightQBytesTensor, quantize_module
+
+
+
+
+mmm = safetensors2.mmm
+
+ONE_MB = 1048576
+sizeofbfloat16 = torch.bfloat16.itemsize
+sizeofint8 = torch.int8.itemsize
+total_pinned_bytes = 0
+physical_memory = psutil.virtual_memory().total
+
+HEADER = '\033[95m'
+ENDC = '\033[0m'
+BOLD = '\033[1m'
+UNBOLD = '\033[0m'
+
+cotenants_map = {
+    "text_encoder": ["vae", "text_encoder_2"],
+    "text_encoder_2": ["vae", "text_encoder"],
+}
+
+class clock:
+    def __init__(self):
+        self.start_time = 0
+        self.end_time = 0
+
+    @classmethod
+    def start(cls):
+        self = cls()
+        self.start_time = time.time()
+        return self
+
+    def stop(self):
+        self.stop_time = time.time()
+
+    def time_gap(self):
+        return self.stop_time - self.start_time
+
+    def format_time_gap(self):
+        return f"{self.stop_time - self.start_time:.2f}s"
+
+
+
+# useful functions to move a group of tensors (to design custom offload patches)
+def move_tensors(obj, device):
+    if torch.is_tensor(obj):
+        return obj.to(device)
+    elif isinstance(obj, dict):
+        _dict = {}
+        for k, v in obj.items():
+            _dict[k] = move_tensors(v, device)
+        return _dict
+    elif isinstance(obj, list):
+        _list = []
+        for v in obj:
+            _list.append(move_tensors(v, device))
+        return _list
+    else:
+        raise TypeError("Tensor or list / dict of tensors expected")
+
+
+def _get_max_reservable_memory(perc_reserved_mem_max):
+    if perc_reserved_mem_max <= 0:
+        perc_reserved_mem_max = 0.40 if os.name == 'nt' else 0.5
+    return perc_reserved_mem_max * physical_memory
+
+def _detect_main_towers(model, verboseLevel=1):
+    cur_blocks_prefix = None
+    towers_modules = []
+    towers_names = []
+
+    for submodule_name, submodule in model.named_modules():
+        if submodule_name == '':
+            continue
+
+        if isinstance(submodule, torch.nn.ModuleList):
+            newList = False
+            if cur_blocks_prefix == None:
+                cur_blocks_prefix = submodule_name + "."
+                newList = True
+            else:
+                if not submodule_name.startswith(cur_blocks_prefix):
+                    cur_blocks_prefix = submodule_name + "."
+                    newList = True
+
+            if newList and len(submodule) >= 5:
+                towers_names.append(submodule_name)
+                towers_modules.append(submodule)
+
+        else:
+            if cur_blocks_prefix is not None:
+                if not submodule_name.startswith(cur_blocks_prefix):
+                    cur_blocks_prefix = None
+
+    return towers_names, towers_modules
+
+
+
+def _get_model(model_path):
+    if os.path.isfile(model_path):
+        return model_path
+
+    from pathlib import Path
+    _path = Path(model_path).parts
+    _filename = _path[-1]
+    _path = _path[:-1]
+    if len(_path) == 1:
+        raise Exception("file not found")
+    else:
+        from huggingface_hub import hf_hub_download  # snapshot_download,
+        repoId = os.path.join(*_path[0:2]).replace("\\", "/")
+
+        if len(_path) > 2:
+            _subfolder = os.path.join(*_path[2:])
+            model_path = hf_hub_download(repo_id=repoId, filename=_filename, subfolder=_subfolder)
+        else:
+            model_path = hf_hub_download(repo_id=repoId, filename=_filename)
+
+    return model_path
+
+
+
+def _remove_model_wrapper(model):
+    if not model._modules is None:
+        if len(model._modules) != 1:
+            return model
+        sub_module = model._modules[next(iter(model._modules))]
+        if hasattr(sub_module, "config") or hasattr(sub_module, "base_model"):
+            return sub_module
+    return model
+
+
+
+def _move_to_pinned_tensor(source_tensor, big_tensor, offset, length):
+    dtype = source_tensor.dtype
+    shape = source_tensor.shape
+    if len(shape) == 0:
+        return source_tensor
+    else:
+        t = source_tensor.view(torch.uint8)
+        t = torch.reshape(t, (length,))
+        # magic swap !
+        big_tensor[offset: offset + length] = t
+        t = big_tensor[offset: offset + length]
+        t = t.view(dtype)
+        t = torch.reshape(t, shape)
+        assert t.is_pinned()
+        return t
+
+def _safetensors_load_file(file_path):
+    from collections import OrderedDict
+    sd = OrderedDict()
+
+    with safetensors2.safe_open(file_path, framework="pt", device="cpu") as f:
+        for k in f.keys():
+            sd[k] = f.get_tensor(k)
+        metadata = f.metadata()
+
+    return sd, metadata
+
+def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
+    if verboseLevel >= 1:
+        if partialPinning:
+            print(f"Partial pinning to RAM of data of '{model_id}'")
+        else:
+            print(f"Pinning data to RAM of '{model_id}'")
+
+    max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
+    if partialPinning:
+        towers_names, _ = _detect_main_towers(model)
+        towers_names = [n + "." for n in towers_names]
+
+    BIG_TENSOR_MAX_SIZE = 2**28  # 256 MB
+    current_big_tensor_size = 0
+    big_tensor_no = 0
+    big_tensors_sizes = []
+    tensor_map_indexes = []
+    total_tensor_bytes = 0
+
+    params_list = []
+    for k, sub_module in model.named_modules():
+        include = True
+        if partialPinning:
+            include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
+        if include:
+            params_list = params_list + list(sub_module.buffers(recurse=False)) + list(sub_module.parameters(recurse=False))
+
+    for p in params_list:
+        if isinstance(p, QTensor):
+            length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
+        else:
+            length = torch.numel(p.data) * p.data.element_size()
+
+        if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
+            big_tensors_sizes.append(current_big_tensor_size)
+            current_big_tensor_size = 0
+            big_tensor_no += 1
+        tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length))
+        current_big_tensor_size += length
+
+        total_tensor_bytes += length
+
+
+    big_tensors_sizes.append(current_big_tensor_size)
+
+    big_tensors = []
+    last_big_tensor = 0
+    total = 0
+
+
+
+    for size in big_tensors_sizes:
+        try:
+            current_big_tensor = torch.empty(size, dtype=torch.uint8, pin_memory=True, device="cpu")
+            big_tensors.append(current_big_tensor)
+        except:
+            print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
+            break
+
+        last_big_tensor += 1
+        total += size
+
+
+    gc.collect()
+
+    tensor_no = 0
+    for p in params_list:
+        big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
+
+        if big_tensor_no >= 0 and big_tensor_no < last_big_tensor:
+            current_big_tensor = big_tensors[big_tensor_no]
+            if isinstance(p, QTensor):
+                length1 = torch.numel(p._data) * p._data.element_size()
+                p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
+                length2 = torch.numel(p._scale) * p._scale.element_size()
+                p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
+            else:
+                length = torch.numel(p.data) * p.data.element_size()
+                p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
+
+        tensor_no += 1
+    global total_pinned_bytes
+    total_pinned_bytes += total
+    gc.collect()
+
+    if verboseLevel >= 1:
+        if total_tensor_bytes == total:
+            print(f"The whole model was pinned to RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
+        else:
+            print(f"{total/ONE_MB:.2f} MB were pinned to RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
+
+    model._already_pinned = True
+
+
+    return
+welcome_displayed = False
+
+def _welcome():
+    global welcome_displayed
+    if welcome_displayed:
+        return
+    welcome_displayed = True
+    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+
+
334
|
+
# def _pin_to_memory_sd(model, sd, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
|
|
335
|
+
# if verboseLevel>=1 :
|
|
336
|
+
# if partialPinning:
|
|
337
|
+
# print(f"Partial pinning to RAM of data of file '{model_id}' while loading it")
|
|
338
|
+
# else:
|
|
339
|
+
# print(f"Pinning data to RAM of file '{model_id}' while loading it")
|
|
340
|
+
|
|
341
|
+
# max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
|
|
342
|
+
# if partialPinning:
|
|
343
|
+
# towers_names, _ = _detect_main_towers(model)
|
|
344
|
+
# towers_names = [n +"." for n in towers_names]
|
|
345
|
+
|
|
346
|
+
# BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
|
|
347
|
+
# current_big_tensor_size = 0
|
|
348
|
+
# big_tensor_no = 0
|
|
349
|
+
# big_tensors_sizes = []
|
|
350
|
+
# tensor_map_indexes = []
|
|
351
|
+
# total_tensor_bytes = 0
|
|
352
|
+
|
|
353
|
+
# for k,t in sd.items():
|
|
354
|
+
# include = True
|
|
355
|
+
# # if isinstance(p, QTensor):
|
|
356
|
+
# # length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
|
|
357
|
+
# # else:
|
|
358
|
+
# # length = torch.numel(p.data) * p.data.element_size()
|
|
359
|
+
# length = torch.numel(t) * t.data.element_size()
|
|
360
|
+
|
|
361
|
+
# if partialPinning:
|
|
362
|
+
# include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
|
|
363
|
+
|
|
364
|
+
# if include:
|
|
365
|
+
# if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
|
|
366
|
+
# big_tensors_sizes.append(current_big_tensor_size)
|
|
367
|
+
# current_big_tensor_size = 0
|
|
368
|
+
# big_tensor_no += 1
|
|
369
|
+
# tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
|
|
370
|
+
# current_big_tensor_size += length
|
|
371
|
+
# else:
|
|
372
|
+
# tensor_map_indexes.append((-1, 0, 0 ))
|
|
373
|
+
# total_tensor_bytes += length
|
|
374
|
+
|
|
375
|
+
# big_tensors_sizes.append(current_big_tensor_size)
|
|
376
|
+
|
|
377
|
+
# big_tensors = []
|
|
378
|
+
# last_big_tensor = 0
|
|
379
|
+
# total = 0
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
# for size in big_tensors_sizes:
|
|
383
|
+
# try:
|
|
384
|
+
# currrent_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True)
|
|
385
|
+
# big_tensors.append(currrent_big_tensor)
|
|
386
|
+
# except:
|
|
387
|
+
# print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
|
|
388
|
+
# break
|
|
389
|
+
|
|
390
|
+
# last_big_tensor += 1
|
|
391
|
+
# total += size
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
# tensor_no = 0
|
|
395
|
+
# for k,t in sd.items():
|
|
396
|
+
# big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
|
|
397
|
+
# if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
|
|
398
|
+
# current_big_tensor = big_tensors[big_tensor_no]
|
|
399
|
+
# # if isinstance(p, QTensor):
|
|
400
|
+
# # length1 = torch.numel(p._data) * p._data.element_size()
|
|
401
|
+
# # p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
|
|
402
|
+
# # length2 = torch.numel(p._scale) * p._scale.element_size()
|
|
403
|
+
# # p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
|
|
404
|
+
# # else:
|
|
405
|
+
# # length = torch.numel(p.data) * p.data.element_size()
|
|
406
|
+
# # p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
|
|
407
|
+
# length = torch.numel(t) * t.data.element_size()
|
|
408
|
+
# t = _move_to_pinned_tensor(t, current_big_tensor, offset, length)
|
|
409
|
+
# sd[k] = t
|
|
410
|
+
# tensor_no += 1
|
|
411
|
+
|
|
412
|
+
# global total_pinned_bytes
|
|
413
|
+
# total_pinned_bytes += total
|
|
414
|
+
|
|
415
|
+
# if verboseLevel >=1:
|
|
416
|
+
# if total_tensor_bytes == total:
|
|
417
|
+
# print(f"The whole model was pinned to RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
|
|
418
|
+
# else:
|
|
419
|
+
# print(f"{total/ONE_MB:.2f} MB were pinned to RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
|
|
420
|
+
|
|
421
|
+
# model._already_pinned = True
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
# return
|
|
425
|
+
|
|
426
|
+
def _quantize_dirty_hack(model):
|
|
427
|
+
# dirty hack: add a hook on state_dict() to return a fake non quantized state_dict if called by Lora Diffusers initialization functions
|
|
428
|
+
setattr( model, "_real_state_dict", model.state_dict)
|
|
429
|
+
from collections import OrderedDict
|
|
430
|
+
import traceback
|
|
431
|
+
|
|
432
|
+
def state_dict_for_lora(self):
|
|
433
|
+
real_sd = self._real_state_dict()
|
|
434
|
+
fakeit = False
|
|
435
|
+
stack = traceback.extract_stack(f=None, limit=5)
|
|
436
|
+
for frame in stack:
|
|
437
|
+
if "_lora_" in frame.name:
|
|
438
|
+
fakeit = True
|
|
439
|
+
break
|
|
440
|
+
|
|
441
|
+
if not fakeit:
|
|
442
|
+
return real_sd
|
|
443
|
+
sd = OrderedDict()
|
|
444
|
+
for k in real_sd:
|
|
445
|
+
v = real_sd[k]
|
|
446
|
+
if k.endswith("._data"):
|
|
447
|
+
k = k[:len(k)-6]
|
|
448
|
+
sd[k] = v
|
|
449
|
+
return sd
|
|
450
|
+
|
|
451
|
+
setattr(model, "state_dict", functools.update_wrapper(functools.partial(state_dict_for_lora, model), model.state_dict) )
|
|
452
|
+
|
|
453
|
+
def _quantization_map(model):
|
|
454
|
+
from optimum.quanto import quantization_map
|
|
455
|
+
return quantization_map(model)
|
|
456
|
+
|
|
457
|
+
def _set_module_by_name(parent_module, name, child_module):
|
|
458
|
+
module_names = name.split(".")
|
|
459
|
+
if len(module_names) == 1:
|
|
460
|
+
setattr(parent_module, name, child_module)
|
|
461
|
+
else:
|
|
462
|
+
parent_module_name = name[: name.rindex(".")]
|
|
463
|
+
parent_module = parent_module.get_submodule(parent_module_name)
|
|
464
|
+
setattr(parent_module, module_names[-1], child_module)
|
|
465
|
+
|
|
466
|
+
def _quantize_submodule(
|
|
467
|
+
model: torch.nn.Module,
|
|
468
|
+
name: str,
|
|
469
|
+
module: torch.nn.Module,
|
|
470
|
+
weights = None,
|
|
471
|
+
activations = None,
|
|
472
|
+
optimizer = None,
|
|
473
|
+
):
|
|
474
|
+
|
|
475
|
+
qmodule = quantize_module(module, weights=weights, activations=activations, optimizer=optimizer)
|
|
476
|
+
if qmodule is not None:
|
|
477
|
+
_set_module_by_name(model, name, qmodule)
|
|
478
|
+
qmodule.name = name
|
|
479
|
+
for name, param in module.named_parameters():
|
|
480
|
+
# Save device memory by clearing parameters
|
|
481
|
+
setattr(module, name, None)
|
|
482
|
+
del param
|
|
483
|
+
|
|
484
|
+
def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict):
|
|
485
|
+
# change dtype of current meta model parameters because 'requantize' won't update the dtype on non quantized parameters
|
|
486
|
+
for k, p in model.named_parameters():
|
|
487
|
+
if not k in quantization_map and k in state_dict:
|
|
488
|
+
p_in_file = state_dict[k]
|
|
489
|
+
if p.data.dtype != p_in_file.data.dtype:
|
|
490
|
+
p.data = p.data.to(p_in_file.data.dtype)
|
|
491
|
+
|
|
492
|
+
# rebuild quanto objects
|
|
493
|
+
for name, m in model.named_modules():
|
|
494
|
+
qconfig = quantization_map.get(name, None)
|
|
495
|
+
if qconfig is not None:
|
|
496
|
+
weights = qconfig["weights"]
|
|
497
|
+
if weights == "none":
|
|
498
|
+
weights = None
|
|
499
|
+
activations = qconfig["activations"]
|
|
500
|
+
if activations == "none":
|
|
501
|
+
activations = None
|
|
502
|
+
_quantize_submodule(model, name, m, weights=weights, activations=activations)
|
|
503
|
+
|
|
504
|
+
model._quanto_map = quantization_map
|
|
505
|
+
|
|
506
|
+
_quantize_dirty_hack(model)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 1000000000, model_id = 'Unknown'):
|
|
511
|
+
|
|
512
|
+
def compute_submodule_size(submodule):
|
|
513
|
+
size = 0
|
|
514
|
+
for p in submodule.parameters(recurse=False):
|
|
515
|
+
size += torch.numel(p.data) * sizeofbfloat16
|
|
516
|
+
|
|
517
|
+
for p in submodule.buffers(recurse=False):
|
|
518
|
+
size += torch.numel(p.data) * sizeofbfloat16
|
|
519
|
+
|
|
520
|
+
return size
|
|
521
|
+
|
|
522
|
+
total_size =0
|
|
523
|
+
total_excluded = 0
|
|
524
|
+
exclude_list = []
|
|
525
|
+
submodule_size = 0
|
|
526
|
+
submodule_names = []
|
|
527
|
+
cur_blocks_prefix = None
|
|
528
|
+
prev_blocks_prefix = None
|
|
529
|
+
|
|
530
|
+
if hasattr(model_to_quantize, "_quanto_map"):
|
|
531
|
+
print(f"Model '{model_id}' is already quantized")
|
|
532
|
+
return False
|
|
533
|
+
|
|
534
|
+
print(f"Quantization of model '{model_id}' started")
|
|
535
|
+
|
|
536
|
+
for submodule_name, submodule in model_to_quantize.named_modules():
|
|
537
|
+
if isinstance(submodule, QModuleMixin):
|
|
538
|
+
if verboseLevel>=1:
|
|
539
|
+
print("No quantization to do as model is already quantized")
|
|
540
|
+
return False
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
if submodule_name=='':
|
|
544
|
+
continue
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
flush = False
|
|
548
|
+
if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
|
|
549
|
+
if cur_blocks_prefix == None:
|
|
550
|
+
cur_blocks_prefix = submodule_name + "."
|
|
551
|
+
flush = True
|
|
552
|
+
else:
|
|
553
|
+
#if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
|
|
554
|
+
if not submodule_name.startswith(cur_blocks_prefix):
|
|
555
|
+
cur_blocks_prefix = submodule_name + "."
|
|
556
|
+
flush = True
|
|
557
|
+
else:
|
|
558
|
+
if cur_blocks_prefix is not None:
|
|
559
|
+
#if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
|
|
560
|
+
if not submodule_name.startswith(cur_blocks_prefix):
|
|
561
|
+
cur_blocks_prefix = None
|
|
562
|
+
flush = True
|
|
563
|
+
|
|
564
|
+
if flush:
|
|
565
|
+
if submodule_size <= threshold:
|
|
566
|
+
exclude_list += submodule_names
|
|
567
|
+
if verboseLevel >=2:
|
|
568
|
+
print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
|
|
569
|
+
total_excluded += submodule_size
|
|
570
|
+
|
|
571
|
+
submodule_size = 0
|
|
572
|
+
submodule_names = []
|
|
573
|
+
prev_blocks_prefix = cur_blocks_prefix
|
|
574
|
+
size = compute_submodule_size(submodule)
|
|
575
|
+
submodule_size += size
|
|
576
|
+
total_size += size
|
|
577
|
+
submodule_names.append(submodule_name)
|
|
578
|
+
|
|
579
|
+
if submodule_size > 0 and submodule_size <= threshold:
|
|
580
|
+
exclude_list += submodule_names
|
|
581
|
+
if verboseLevel >=2:
|
|
582
|
+
print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
|
|
583
|
+
total_excluded += submodule_size
|
|
584
|
+
|
|
585
|
+
perc_excluded =total_excluded/ total_size if total_size >0 else 1
|
|
586
|
+
if verboseLevel >=2:
|
|
587
|
+
print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB oF {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
|
|
588
|
+
if perc_excluded >= 0.10:
|
|
589
|
+
print(f"Too many many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
|
|
590
|
+
exclude_list = None
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
#quantize(model_to_quantize,weights, exclude= exclude_list)
|
|
594
|
+
pass
|
|
595
|
+
for name, m in model_to_quantize.named_modules():
|
|
596
|
+
if exclude_list is None or not any( name == module_name for module_name in exclude_list):
|
|
597
|
+
_quantize_submodule(model_to_quantize, name, m, weights=weights, activations=None, optimizer=None)
|
|
598
|
+
|
|
599
|
+
# force read non quantized parameters so that their lazy tensors and corresponding mmap are released
|
|
600
|
+
# otherwise we may end up to keep in memory both the quantized and the non quantize model
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
for name, m in model_to_quantize.named_modules():
|
|
604
|
+
# do not read quantized weights (detected them directly or behind an adapter)
|
|
605
|
+
if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
|
|
606
|
+
pass
|
|
607
|
+
else:
|
|
608
|
+
if hasattr(m, "weight") and m.weight is not None:
|
|
609
|
+
m.weight.data = m.weight.data + 0
|
|
610
|
+
|
|
611
|
+
if hasattr(m, "bias") and m.bias is not None:
|
|
612
|
+
m.bias.data = m.bias.data + 0
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
freeze(model_to_quantize)
|
|
616
|
+
torch.cuda.empty_cache()
|
|
617
|
+
gc.collect()
|
|
618
|
+
quantization_map = _quantization_map(model_to_quantize)
|
|
619
|
+
model_to_quantize._quanto_map = quantization_map
|
|
620
|
+
|
|
621
|
+
_quantize_dirty_hack(model_to_quantize)
|
|
622
|
+
|
|
623
|
+
print(f"Quantization of model '{model_id}' done")
|
|
624
|
+
|
|
625
|
+
return True
|
|
626
|
+
|
|
627
|
+
def get_model_name(model):
|
|
628
|
+
return model.name
|
|
629
|
+
|
|
630
|
+
class HfHook:
|
|
631
|
+
def __init__(self):
|
|
632
|
+
self.execution_device = "cuda"
|
|
633
|
+
|
|
634
|
+
def detach_hook(self, module):
|
|
635
|
+
pass
|
|
636
|
+
|
|
637
|
+
class offload:
|
|
638
|
+
def __init__(self):
|
|
639
|
+
self.active_models = []
|
|
640
|
+
self.active_models_ids = []
|
|
641
|
+
self.active_subcaches = {}
|
|
642
|
+
self.models = {}
|
|
643
|
+
self.verboseLevel = 0
|
|
644
|
+
self.modules_data = {}
|
|
645
|
+
self.blocks_of_modules = {}
|
|
646
|
+
self.blocks_of_modules_sizes = {}
|
|
647
|
+
self.anyCompiledModule = False
|
|
648
|
+
self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
|
|
649
|
+
self.last_reserved_mem_check =0
|
|
650
|
+
self.loaded_blocks = {}
|
|
651
|
+
self.prev_blocks_names = {}
|
|
652
|
+
self.next_blocks_names = {}
|
|
653
|
+
self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
|
|
654
|
+
self.transfer_stream = torch.cuda.Stream()
|
|
655
|
+
self.async_transfers = False
|
|
656
|
+
|
|
657
|
+
def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
|
|
658
|
+
|
|
659
|
+
entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
|
|
660
|
+
if entry_name in self.blocks_of_modules:
|
|
661
|
+
blocks_params = self.blocks_of_modules[entry_name]
|
|
662
|
+
blocks_params_size = self.blocks_of_modules_sizes[entry_name]
|
|
663
|
+
else:
|
|
664
|
+
blocks_params = []
|
|
665
|
+
self.blocks_of_modules[entry_name] = blocks_params
|
|
666
|
+
blocks_params_size = 0
|
|
667
|
+
if blocks_name !=None:
|
|
668
|
+
|
|
669
|
+
prev_entry_name = None if prev_block_name == None else model_id + "/" + prev_block_name
|
|
670
|
+
self.prev_blocks_names[entry_name] = prev_entry_name
|
|
671
|
+
if not prev_block_name == None:
|
|
672
|
+
self.next_blocks_names[prev_entry_name] = entry_name
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
for k,p in submodule.named_parameters(recurse=False):
|
|
676
|
+
blocks_params.append(p)
|
|
677
|
+
if isinstance(p, QTensor):
|
|
678
|
+
blocks_params_size += p._data.nbytes
|
|
679
|
+
blocks_params_size += p._scale.nbytes
|
|
680
|
+
else:
|
|
681
|
+
blocks_params_size += p.data.nbytes
|
|
682
|
+
|
|
683
|
+
for p in submodule.buffers(recurse=False):
|
|
684
|
+
blocks_params.append(p)
|
|
685
|
+
blocks_params_size += p.data.nbytes
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
self.blocks_of_modules_sizes[entry_name] = blocks_params_size
|
|
689
|
+
|
|
690
|
+
return blocks_params_size
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def can_model_be_cotenant(self, model_id):
|
|
694
|
+
potential_cotenants= cotenants_map.get(model_id, None)
|
|
695
|
+
if potential_cotenants is None:
|
|
696
|
+
return False
|
|
697
|
+
for existing_cotenant in self.active_models_ids:
|
|
698
|
+
if existing_cotenant not in potential_cotenants:
|
|
699
|
+
return False
|
|
700
|
+
return True
|
|
701
|
+
|
|
702
|
+
def gpu_load_blocks(self, model_id, blocks_name, async_load = False):
|
|
703
|
+
# cl = clock.start()
|
|
704
|
+
import weakref
|
|
705
|
+
|
|
706
|
+
if blocks_name != None:
|
|
707
|
+
self.loaded_blocks[model_id] = blocks_name
|
|
708
|
+
|
|
709
|
+
def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream = None):
|
|
710
|
+
with torch.cuda.stream(stream_to_use):
|
|
711
|
+
for p in blocks_params:
|
|
712
|
+
if isinstance(p, QTensor):
|
|
713
|
+
# need formal transfer to cuda otherwise quantized tensor will be still considered in cpu and compilation will fail
|
|
714
|
+
q=p.to("cuda",non_blocking=True)
|
|
715
|
+
#q = torch.nn.Parameter(q , requires_grad=False)
|
|
716
|
+
|
|
717
|
+
ref = weakref.getweakrefs(p)
|
|
718
|
+
if ref:
|
|
719
|
+
torch._C._swap_tensor_impl(p, q)
|
|
720
|
+
else:
|
|
721
|
+
torch.utils.swap_tensors(p, q)
|
|
722
|
+
|
|
723
|
+
# p._data = p._data.cuda(non_blocking=True)
|
|
724
|
+
# p._scale = p._scale.cuda(non_blocking=True)
|
|
725
|
+
else:
|
|
726
|
+
p.data = p.data.cuda(non_blocking=True)
|
|
727
|
+
|
|
728
|
+
if record_for_stream != None:
|
|
729
|
+
if isinstance(p, QTensor):
|
|
730
|
+
p._data.record_stream(record_for_stream)
|
|
731
|
+
p._scale.record_stream(record_for_stream)
|
|
732
|
+
else:
|
|
733
|
+
p.data.record_stream(record_for_stream)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
|
|
737
|
+
if self.verboseLevel >=2:
|
|
738
|
+
model = self.models[model_id]
|
|
739
|
+
model_name = model._get_name()
|
|
740
|
+
print(f"Loading model {entry_name} ({model_name}) in GPU")
|
|
741
|
+
|
|
742
|
+
|
|
743
|
+
if self.async_transfers and blocks_name != None:
|
|
744
|
+
first = self.prev_blocks_names[entry_name] == None
|
|
745
|
+
next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
|
|
746
|
+
if first:
|
|
747
|
+
cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
|
|
748
|
+
torch.cuda.synchronize()
|
|
749
|
+
|
|
750
|
+
if next_blocks_entry != None:
|
|
751
|
+
cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry]) #, self.default_stream
|
|
752
|
+
|
|
753
|
+
else:
|
|
754
|
+
cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
|
|
755
|
+
torch.cuda.synchronize()
|
|
756
|
+
# cl.stop()
|
|
757
|
+
# print(f"load time: {cl.format_time_gap()}")
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def gpu_unload_blocks(self, model_id, blocks_name):
|
|
761
|
+
# cl = clock.start()
|
|
762
|
+
import weakref
|
|
763
|
+
if blocks_name != None:
|
|
764
|
+
self.loaded_blocks[model_id] = None
|
|
765
|
+
|
|
766
|
+
blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
|
|
767
|
+
|
|
768
|
+
if self.verboseLevel >=2:
|
|
769
|
+
model = self.models[model_id]
|
|
770
|
+
model_name = model._get_name()
|
|
771
|
+
print(f"Unloading model {blocks_name} ({model_name}) from GPU")
|
|
772
|
+
|
|
773
|
+
blocks_params = self.blocks_of_modules[blocks_name]
|
|
774
|
+
if "transformer/double_blocks.0" == blocks_name:
|
|
775
|
+
pass
|
|
776
|
+
parameters_data = self.modules_data[model_id]
|
|
777
|
+
for p in blocks_params:
|
|
778
|
+
if isinstance(p, QTensor):
|
|
779
|
+
data = parameters_data[p]
|
|
780
|
+
|
|
781
|
+
# needs to create a new WeightQBytesTensor with the cached data that is in the cpu device like the data (faketensor p is still in cuda)
|
|
782
|
+
q = WeightQBytesTensor.create(p.qtype, p.axis, p.size(), p.stride(), data[0], data[1], activation_qtype=p.activation_qtype, requires_grad=p.requires_grad )
|
|
783
|
+
#q = torch.nn.Parameter(q , requires_grad=False)
|
|
784
|
+
ref = weakref.getweakrefs(p)
|
|
785
|
+
if ref:
|
|
786
|
+
torch._C._swap_tensor_impl(p, q)
|
|
787
|
+
else:
|
|
788
|
+
torch.utils.swap_tensors(p, q)
|
|
789
|
+
|
|
790
|
+
# p._data = data[0]
|
|
791
|
+
# p._scale = data[1]
|
|
792
|
+
else:
|
|
793
|
+
p.data = parameters_data[p]
|
|
794
|
+
# cl.stop()
|
|
795
|
+
# print(f"unload time: {cl.format_time_gap()}")
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
def gpu_load(self, model_id):
|
|
799
|
+
model = self.models[model_id]
|
|
800
|
+
self.active_models.append(model)
|
|
801
|
+
self.active_models_ids.append(model_id)
|
|
802
|
+
|
|
803
|
+
self.gpu_load_blocks(model_id, None)
|
|
804
|
+
|
|
805
|
+
# torch.cuda.current_stream().synchronize()
|
|
806
|
+
|
|
807
|
+
def unload_all(self):
|
|
808
|
+
for model_id in self.active_models_ids:
|
|
809
|
+
self.gpu_unload_blocks(model_id, None)
|
|
810
|
+
loaded_block = self.loaded_blocks[model_id]
|
|
811
|
+
if loaded_block != None:
|
|
812
|
+
self.gpu_unload_blocks(model_id, loaded_block)
|
|
813
|
+
self.loaded_blocks[model_id] = None
|
|
814
|
+
|
|
815
|
+
self.active_models = []
|
|
816
|
+
self.active_models_ids = []
|
|
817
|
+
self.active_subcaches = []
|
|
818
|
+
torch.cuda.empty_cache()
|
|
819
|
+
gc.collect()
|
|
820
|
+
self.last_reserved_mem_check = time.time()
|
|
821
|
+
|
|
822
|
+
def move_args_to_gpu(self, *args, **kwargs):
|
|
823
|
+
new_args= []
|
|
824
|
+
new_kwargs={}
|
|
825
|
+
for arg in args:
|
|
826
|
+
if torch.is_tensor(arg):
|
|
827
|
+
if arg.dtype == torch.float32:
|
|
828
|
+
arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
|
|
829
|
+
else:
|
|
830
|
+
arg = arg.cuda(non_blocking=True)
|
|
831
|
+
new_args.append(arg)
|
|
832
|
+
|
|
833
|
+
for k in kwargs:
|
|
834
|
+
arg = kwargs[k]
|
|
835
|
+
if torch.is_tensor(arg):
|
|
836
|
+
if arg.dtype == torch.float32:
|
|
837
|
+
arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
|
|
838
|
+
else:
|
|
839
|
+
arg = arg.cuda(non_blocking=True)
|
|
840
|
+
new_kwargs[k]= arg
|
|
841
|
+
|
|
842
|
+
return new_args, new_kwargs
|
|
843
|
+
|
|
844
|
+
def ready_to_check_mem(self):
|
|
845
|
+
if self.anyCompiledModule:
|
|
846
|
+
return
|
|
847
|
+
cur_clock = time.time()
|
|
848
|
+
# can't check at each call if we can empty the cuda cache as quering the reserved memory value is a time consuming operation
|
|
849
|
+
if (cur_clock - self.last_reserved_mem_check)<0.200:
|
|
850
|
+
return False
|
|
851
|
+
self.last_reserved_mem_check = cur_clock
|
|
852
|
+
return True
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
def empty_cache_if_needed(self):
|
|
856
|
+
mem_reserved = torch.cuda.memory_reserved()
|
|
857
|
+
mem_threshold = 0.9*self.device_mem_capacity
|
|
858
|
+
if mem_reserved >= mem_threshold:
|
|
859
|
+
mem_allocated = torch.cuda.memory_allocated()
|
|
860
|
+
if mem_allocated <= 0.70 * mem_reserved:
|
|
861
|
+
# print(f"Cuda empty cache triggered as Allocated Memory ({mem_allocated/1024000:0f} MB) is lot less than Cached Memory ({mem_reserved/1024000:0f} MB) ")
|
|
862
|
+
torch.cuda.empty_cache()
|
|
863
|
+
tm= time.time()
|
|
864
|
+
if self.verboseLevel >=2:
|
|
865
|
+
print(f"Empty Cuda cache at {tm}")
|
|
866
|
+
# print(f"New cached memory after purge is {torch.cuda.memory_reserved()/1024000:0f} MB) ")
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
def any_param_or_buffer(self, target_module: torch.nn.Module):
|
|
870
|
+
|
|
871
|
+
for _ in target_module.parameters(recurse= False):
|
|
872
|
+
return True
|
|
873
|
+
|
|
874
|
+
for _ in target_module.buffers(recurse= False):
|
|
875
|
+
return True
|
|
876
|
+
|
|
877
|
+
return False
|
|
878
|
+
|
|
879
|
+
def hook_load_data_if_needed(self, target_module, model_id,blocks_name, context):
|
|
880
|
+
|
|
881
|
+
@torch.compiler.disable()
|
|
882
|
+
def load_data_if_needed(module, *args, **kwargs):
|
|
883
|
+
some_context = context #for debugging
|
|
884
|
+
if blocks_name == None:
|
|
885
|
+
if self.ready_to_check_mem():
|
|
886
|
+
self.empty_cache_if_needed()
|
|
887
|
+
else:
|
|
888
|
+
loaded_block = self.loaded_blocks[model_id]
|
|
889
|
+
if (loaded_block == None or loaded_block != blocks_name) :
|
|
890
|
+
if loaded_block != None:
|
|
891
|
+
self.gpu_unload_blocks(model_id, loaded_block)
|
|
892
|
+
if self.ready_to_check_mem():
|
|
893
|
+
self.empty_cache_if_needed()
|
|
894
|
+
self.loaded_blocks[model_id] = blocks_name
|
|
895
|
+
self.gpu_load_blocks(model_id, blocks_name)
|
|
896
|
+
|
|
897
|
+
target_module.register_forward_pre_hook(load_data_if_needed)
|
|
898
|
+
|
|
899
|
+
|
|
900
|
+
def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
|
|
901
|
+
|
|
902
|
+
def check_empty_cuda_cache(module, *args, **kwargs):
|
|
903
|
+
# if self.ready_to_check_mem():
|
|
904
|
+
# self.empty_cache_if_needed()
|
|
905
|
+
if blocks_name == None:
|
|
906
|
+
if self.ready_to_check_mem():
|
|
907
|
+
self.empty_cache_if_needed()
|
|
908
|
+
else:
|
|
909
|
+
loaded_block = self.loaded_blocks[model_id]
|
|
910
|
+
if (loaded_block == None or loaded_block != blocks_name) :
|
|
911
|
+
if loaded_block != None:
|
|
912
|
+
self.gpu_unload_blocks(model_id, loaded_block)
|
|
913
|
+
if self.ready_to_check_mem():
|
|
914
|
+
self.empty_cache_if_needed()
|
|
915
|
+
self.loaded_blocks[model_id] = blocks_name
|
|
916
|
+
self.gpu_load_blocks(model_id, blocks_name)
|
|
917
|
+
|
|
918
|
+
return previous_method(*args, **kwargs)
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
if hasattr(target_module, "_mm_id"):
|
|
922
|
+
orig_model_id = getattr(target_module, "_mm_id")
|
|
923
|
+
if self.verboseLevel >=2:
|
|
924
|
+
print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module '{orig_model_id}' ")
|
|
925
|
+
assert not self.any_param_or_buffer(target_module)
|
|
926
|
+
|
|
927
|
+
return
|
|
928
|
+
setattr(target_module, "_mm_id", model_id)
|
|
929
|
+
setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
def hook_change_module(self, target_module, model, model_id, module_id, previous_method):
|
|
933
|
+
def check_change_module(module, *args, **kwargs):
|
|
934
|
+
performEmptyCacheTest = False
|
|
935
|
+
if not model_id in self.active_models_ids:
|
|
936
|
+
new_model_id = getattr(module, "_mm_id")
|
|
937
|
+
# do not always unload existing models if it is more efficient to keep in them in the GPU
|
|
938
|
+
# (e.g: small modules whose calls are text encoders)
|
|
939
|
+
if not self.can_model_be_cotenant(new_model_id) :
|
|
940
|
+
self.unload_all()
|
|
941
|
+
performEmptyCacheTest = False
|
|
942
|
+
self.gpu_load(new_model_id)
|
|
943
|
+
# transfer leftovers inputs that were incorrectly created in the RAM (mostly due to some .device tests that returned incorrectly "cpu")
|
|
944
|
+
args, kwargs = self.move_args_to_gpu(*args, **kwargs)
|
|
945
|
+
if performEmptyCacheTest:
|
|
946
|
+
self.empty_cache_if_needed()
|
|
947
|
+
|
|
948
|
+
return previous_method(*args, **kwargs)
|
|
949
|
+
|
|
950
|
+
if hasattr(target_module, "_mm_id"):
|
|
951
|
+
return
|
|
952
|
+
setattr(target_module, "_mm_id", model_id)
|
|
953
|
+
|
|
954
|
+
setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method) )
|
|
955
|
+
|
|
956
|
+
if not self.verboseLevel >=1:
|
|
957
|
+
return
|
|
958
|
+
|
|
959
|
+
if module_id == None or module_id =='':
|
|
960
|
+
model_name = model._get_name()
|
|
961
|
+
print(f"Hooked in model '{model_id}' ({model_name})")
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
# Not implemented yet, but why would one want to get rid of these features ?
|
|
965
|
+
# def unhook_module(module: torch.nn.Module):
|
|
966
|
+
# if not hasattr(module,"_mm_id"):
|
|
967
|
+
# return
|
|
968
|
+
|
|
969
|
+
# delattr(module, "_mm_id")
|
|
970
|
+
|
|
971
|
+
# def unhook_all(parent_module: torch.nn.Module):
|
|
972
|
+
# for module in parent_module.components.items():
|
|
973
|
+
# self.unhook_module(module)
|
|
974
|
+
|
|
975
|
+
def fast_load_transformers_model(model_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verbose_level = 1):
|
|
976
|
+
"""
|
|
977
|
+
quick version of .LoadfromPretrained of the transformers library
|
|
978
|
+
used to build a model and load the corresponding weights (quantized or not)
|
|
979
|
+
"""
|
|
980
|
+
import os.path
|
|
981
|
+
from accelerate import init_empty_weights
|
|
982
|
+
|
|
983
|
+
if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
|
|
984
|
+
raise Exception("full model path to file expected")
|
|
985
|
+
|
|
986
|
+
model_path = _get_model(model_path)
|
|
987
|
+
|
|
988
|
+
with safetensors2.safe_open(model_path) as f:
|
|
989
|
+
metadata = f.metadata()
|
|
990
|
+
|
|
991
|
+
if metadata is None:
|
|
992
|
+
transformer_config = None
|
|
993
|
+
else:
|
|
994
|
+
transformer_config = metadata.get("config", None)
|
|
995
|
+
|
|
996
|
+
if transformer_config == None:
|
|
997
|
+
config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
|
|
998
|
+
|
|
999
|
+
if not os.path.isfile(config_fullpath):
|
|
1000
|
+
raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
|
|
1001
|
+
|
|
1002
|
+
with open(config_fullpath, "r", encoding="utf-8") as reader:
|
|
1003
|
+
text = reader.read()
|
|
1004
|
+
transformer_config= json.loads(text)
|
|
1005
|
+
|
|
1006
|
+
|
|
1007
|
+
if "architectures" in transformer_config:
|
|
1008
|
+
architectures = transformer_config["architectures"]
|
|
1009
|
+
class_name = architectures[0]
|
|
1010
|
+
|
|
1011
|
+
module = __import__("transformers")
|
|
1012
|
+
transfomer_class = getattr(module, class_name)
|
|
1013
|
+
from transformers import AutoConfig
|
|
1014
|
+
|
|
1015
|
+
import tempfile
|
|
1016
|
+
with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
|
|
1017
|
+
fp.write(json.dumps(transformer_config))
|
|
1018
|
+
fp.close()
|
|
1019
|
+
config_obj = AutoConfig.from_pretrained(fp.name)
|
|
1020
|
+
os.remove(fp.name)
|
|
1021
|
+
|
|
1022
|
+
#needed to keep inits of non persistent buffers
|
|
1023
|
+
with init_empty_weights():
|
|
1024
|
+
model = transfomer_class(config_obj)
|
|
1025
|
+
|
|
1026
|
+
model = model.base_model
|
|
1027
|
+
|
|
1028
|
+
elif "_class_name" in transformer_config:
|
|
1029
|
+
class_name = transformer_config["_class_name"]
|
|
1030
|
+
|
|
1031
|
+
module = __import__("diffusers")
|
|
1032
|
+
transfomer_class = getattr(module, class_name)
|
|
1033
|
+
|
|
1034
|
+
with init_empty_weights():
|
|
1035
|
+
model = transfomer_class.from_config(transformer_config)
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
torch.set_default_device('cpu')
|
|
1039
|
+
|
|
1040
|
+
model._config = transformer_config
|
|
1041
|
+
|
|
1042
|
+
load_model_data(model,model_path, do_quantize = do_quantize, quantization_type = quantization_type, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verbose_level )
|
|
1043
|
+
|
|
1044
|
+
return model
|
|
1045
|
+
|
|
1046
|
+
|
|
1047
|
+
|
|
1048
|
+
def load_model_data(model, file_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = 1):
|
|
1049
|
+
"""
|
|
1050
|
+
Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
|
|
1051
|
+
"""
|
|
1052
|
+
|
|
1053
|
+
file_path = _get_model(file_path)
|
|
1054
|
+
safetensors2.verboseLevel = verboseLevel
|
|
1055
|
+
model = _remove_model_wrapper(model)
|
|
1056
|
+
|
|
1057
|
+
# if pinToMemory and do_quantize:
|
|
1058
|
+
# raise Exception("Pinning and Quantization can not be used at the same time")
|
|
1059
|
+
|
|
1060
|
+
if not (".safetensors" in file_path or ".sft" in file_path):
|
|
1061
|
+
if pinToMemory:
|
|
1062
|
+
raise Exception("Pinning to memory while loading only supported for safe tensors files")
|
|
1063
|
+
state_dict = torch.load(file_path, weights_only=True)
|
|
1064
|
+
if "module" in state_dict:
|
|
1065
|
+
state_dict = state_dict["module"]
|
|
1066
|
+
else:
|
|
1067
|
+
state_dict, metadata = _safetensors_load_file(file_path)
|
|
1068
|
+
|
|
1069
|
+
|
|
1070
|
+
# if pinToMemory:
|
|
1071
|
+
# _pin_to_memory_sd(model,state_dict, file_path, partialPinning = partialPinning, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel = verboseLevel)
|
|
1072
|
+
|
|
1073
|
+
# with safetensors2.safe_open(file_path) as f:
|
|
1074
|
+
# metadata = f.metadata()
|
|
1075
|
+
|
|
1076
|
+
|
|
1077
|
+
if metadata is None:
|
|
1078
|
+
quantization_map = None
|
|
1079
|
+
else:
|
|
1080
|
+
quantization_map = metadata.get("quantization_map", None)
|
|
1081
|
+
config = metadata.get("config", None)
|
|
1082
|
+
if config is not None:
|
|
1083
|
+
model._config = config
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
|
|
1087
|
+
if quantization_map is None:
|
|
1088
|
+
pos = str.rfind(file_path, ".")
|
|
1089
|
+
if pos > 0:
|
|
1090
|
+
quantization_map_path = file_path[:pos]
|
|
1091
|
+
quantization_map_path += "_map.json"
|
|
1092
|
+
|
|
1093
|
+
if os.path.isfile(quantization_map_path):
|
|
1094
|
+
with open(quantization_map_path, 'r') as f:
|
|
1095
|
+
quantization_map = json.load(f)
|
|
1096
|
+
|
|
1097
|
+
|
|
1098
|
+
|
|
1099
|
+
if quantization_map is None :
|
|
1100
|
+
if "quanto" in file_path and not do_quantize:
|
|
1101
|
+
print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
|
|
1102
|
+
else:
|
|
1103
|
+
_requantize(model, state_dict, quantization_map)
|
|
1104
|
+
|
|
1105
|
+
missing_keys , unexpected_keys = model.load_state_dict(state_dict, strict = quantization_map is None, assign = True )
|
|
1106
|
+
del state_dict
|
|
1107
|
+
|
|
1108
|
+
if do_quantize:
|
|
1109
|
+
if quantization_map is None:
|
|
1110
|
+
if _quantize(model, quantization_type, verboseLevel=verboseLevel, model_id=file_path):
|
|
1111
|
+
quantization_map = model._quanto_map
|
|
1112
|
+
else:
|
|
1113
|
+
if verboseLevel >=1:
|
|
1114
|
+
print("Model already quantized")
|
|
1115
|
+
|
|
1116
|
+
if pinToMemory:
|
|
1117
|
+
_pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
|
|
1118
|
+
|
|
1119
|
+
return
|
|
1120
|
+
|
|
1121
|
+
def save_model(model, file_path, do_quantize = False, quantization_type = qint8, verboseLevel = 1 ):
|
|
1122
|
+
"""save the weights of a model and quantize them if requested
|
|
1123
|
+
These weights can be loaded again using 'load_model_data'
|
|
1124
|
+
"""
|
|
1125
|
+
|
|
1126
|
+
config = None
|
|
1127
|
+
|
|
1128
|
+
if hasattr(model, "_config"):
|
|
1129
|
+
config = model._config
|
|
1130
|
+
elif hasattr(model, "config"):
|
|
1131
|
+
config_fullpath = None
|
|
1132
|
+
config_obj = getattr(model,"config")
|
|
1133
|
+
config_path = getattr(config_obj,"_name_or_path", None)
|
|
1134
|
+
if config_path != None:
|
|
1135
|
+
config_fullpath = os.path.join(config_path, "config.json")
|
|
1136
|
+
if not os.path.isfile(config_fullpath):
|
|
1137
|
+
config_fullpath = None
|
|
1138
|
+
if config_fullpath is None:
|
|
1139
|
+
config_fullpath = os.path.join(os.path.dirname(file_path), "config.json")
|
|
1140
|
+
if os.path.isfile(config_fullpath):
|
|
1141
|
+
with open(config_fullpath, "r", encoding="utf-8") as reader:
|
|
1142
|
+
text = reader.read()
|
|
1143
|
+
config= json.loads(text)
|
|
1144
|
+
|
|
1145
|
+
if do_quantize:
|
|
1146
|
+
_quantize(model, weights=quantization_type, model_id=file_path)
|
|
1147
|
+
|
|
1148
|
+
quantization_map = getattr(model, "_quanto_map", None)
|
|
1149
|
+
|
|
1150
|
+
if verboseLevel >=1:
|
|
1151
|
+
print(f"Saving file '{file_path}")
|
|
1152
|
+
safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
|
|
1153
|
+
if verboseLevel >=1:
|
|
1154
|
+
print(f"File '{file_path} saved")
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = 1):
|
|
1160
|
+
"""Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
|
|
1161
|
+
pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
|
|
1162
|
+
quantizeTransformer: set True by default will quantize on the fly the video / image model
|
|
1163
|
+
pinnedMemory: move models in reserved memor. This allows very fast performance but requires 50% extra RAM (usually >=64 GB)
|
|
1164
|
+
extraModelsToQuantize: a list of models to be also quantized on the fly (e.g the text_encoder), useful to reduce bith RAM and VRAM consumption
|
|
1165
|
+
budgets: 0 by default (unlimited). If non 0, it corresponds to the maximum size in MB that every model will occupy at any moment
|
|
1166
|
+
(in fact the real usage is twice this number). It is very efficient to reduce VRAM consumption but this feature may be very slow
|
|
1167
|
+
if pinnedMemory is not enabled
|
|
1168
|
+
"""
|
|
1169
|
+
self = offload()
|
|
1170
|
+
self.verboseLevel = verboseLevel
|
|
1171
|
+
safetensors2.verboseLevel = verboseLevel
|
|
1172
|
+
self.modules_data = {}
|
|
1173
|
+
model_budgets = {}
|
|
1174
|
+
|
|
1175
|
+
windows_os = os.name == 'nt'
|
|
1176
|
+
global total_pinned_bytes
|
|
1177
|
+
|
|
1178
|
+
|
|
1179
|
+
budget = 0
|
|
1180
|
+
if not budgets is None:
|
|
1181
|
+
if isinstance(budgets , dict):
|
|
1182
|
+
model_budgets = budgets
|
|
1183
|
+
else:
|
|
1184
|
+
budget = int(budgets) * ONE_MB
|
|
1185
|
+
|
|
1186
|
+
# if (budgets!= None or budget >0) :
|
|
1187
|
+
# self.async_transfers = True
|
|
1188
|
+
self.async_transfers = asyncTransfers
|
|
1189
|
+
|
|
1190
|
+
|
|
1191
|
+
|
|
1192
|
+
torch.set_default_device('cpu')
|
|
1193
|
+
|
|
1194
|
+
if hasattr(pipe_or_dict_of_modules, "components"):
|
|
1195
|
+
# create a fake Accelerate parameter so that lora loading doesn't change the device
|
|
1196
|
+
pipe_or_dict_of_modules.hf_device_map = torch.device("cuda")
|
|
1197
|
+
pipe_or_dict_of_modules= pipe_or_dict_of_modules.components
|
|
1198
|
+
|
|
1199
|
+
|
|
1200
|
+
models = {k: _remove_model_wrapper(v) for k, v in pipe_or_dict_of_modules.items() if isinstance(v, torch.nn.Module)}
|
|
1201
|
+
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
_welcome()
|
|
1205
|
+
|
|
1206
|
+
self.models = models
|
|
1207
|
+
|
|
1208
|
+
extraModelsToQuantize = extraModelsToQuantize if extraModelsToQuantize is not None else []
|
|
1209
|
+
if not isinstance(extraModelsToQuantize, list):
|
|
1210
|
+
extraModelsToQuantize= [extraModelsToQuantize]
|
|
1211
|
+
if quantizeTransformer:
|
|
1212
|
+
extraModelsToQuantize.append("transformer")
|
|
1213
|
+
models_to_quantize = extraModelsToQuantize
|
|
1214
|
+
|
|
1215
|
+
modelsToPin = []
|
|
1216
|
+
pinAllModels = False
|
|
1217
|
+
if isinstance(pinnedMemory, bool):
|
|
1218
|
+
pinAllModels = pinnedMemory
|
|
1219
|
+
elif isinstance(pinnedMemory, list):
|
|
1220
|
+
modelsToPin = pinnedMemory
|
|
1221
|
+
else:
|
|
1222
|
+
modelsToPin = [pinnedMemory]
|
|
1223
|
+
|
|
1224
|
+
modelsToCompile = []
|
|
1225
|
+
compileAllModels = False
|
|
1226
|
+
if isinstance(compile, bool):
|
|
1227
|
+
compileAllModels = compile
|
|
1228
|
+
elif isinstance(compile, list):
|
|
1229
|
+
modelsToCompile = compile
|
|
1230
|
+
else:
|
|
1231
|
+
modelsToCompile = [compile]
|
|
1232
|
+
|
|
1233
|
+
self.anyCompiledModule = compileAllModels or len(modelsToCompile)>0
|
|
1234
|
+
if self.anyCompiledModule:
|
|
1235
|
+
torch._dynamo.config.cache_size_limit = 10000
|
|
1236
|
+
|
|
1237
|
+
max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
|
|
1238
|
+
|
|
1239
|
+
estimatesBytesToPin = 0
|
|
1240
|
+
|
|
1241
|
+
for model_id in models:
|
|
1242
|
+
current_model: torch.nn.Module = models[model_id]
|
|
1243
|
+
# make sure that no RAM or GPU memory is not allocated for gradiant / training
|
|
1244
|
+
current_model.to("cpu").eval()
|
|
1245
|
+
|
|
1246
|
+
# if the model has just been quantized so there is no need to quantize it again
|
|
1247
|
+
if model_id in models_to_quantize:
|
|
1248
|
+
_quantize(current_model, weights=qint8, verboseLevel = self.verboseLevel, model_id=model_id)
        modelPinned = (pinAllModels or model_id in modelsToPin) and not hasattr(current_model, "_already_pinned")

        current_model_size = 0
        # load any remaining unread lazy safetensors into RAM to free the open cache files
        for p in current_model.parameters():
            if isinstance(p, QTensor):
                # # fix quanto bug (seems to have been fixed)
                # if not modelPinned and p._scale.dtype == torch.float32:
                #     p._scale = p._scale.to(torch.bfloat16)
                current_model_size += torch.numel(p._scale) * p._scale.element_size()
                current_model_size += torch.numel(p._data) * p._data.element_size()
            else:
                if p.data.dtype == torch.float32:
                    # convert any leftover float32 weights to bfloat16 to halve the model memory footprint
                    p.data = p.data.to(torch.bfloat16)
                current_model_size += torch.numel(p.data) * p.data.element_size()

        for b in current_model.buffers():
            if b.data.dtype == torch.float32:
                # convert any leftover float32 buffers to bfloat16 to halve the model memory footprint
                b.data = b.data.to(torch.bfloat16)
            current_model_size += torch.numel(b.data) * b.data.element_size()
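        # Worked example of the bfloat16 cast above (figures are illustrative): a 12 billion parameter
        # model stored in float32 weighs about 12e9 * 4 bytes = 48 GB; cast to bfloat16 (2 bytes per
        # element) the same weights take about 24 GB, i.e. half the footprint.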
        if modelPinned:
            estimatesBytesToPin += current_model_size

        model_budget = model_budgets[model_id] * ONE_MB if model_id in model_budgets else budget

        if model_budget > 0 and model_budget > current_model_size:
            model_budget = 0

        model_budgets[model_id] = model_budget
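        # A budget of 0 means "no budget": the model is smaller than the requested budget, so it is not
        # split into blocks (see the `current_budget > 0` test further below).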
    partialPinning = False

    if estimatesBytesToPin > 0 and estimatesBytesToPin >= (max_reservable_memory - total_pinned_bytes):
        if self.verboseLevel >= 1:
            print(f"Switching to partial pinning since the full requirement for pinned models is {estimatesBytesToPin/ONE_MB:0.1f} MB while the estimated reservable RAM is {max_reservable_memory/ONE_MB:0.1f} MB")
        partialPinning = True
    # Hook the forward methods of the modules
    for model_id in models:
        current_model: torch.nn.Module = models[model_id]
        current_budget = model_budgets[model_id]
        current_size = 0
        cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
        self.loaded_blocks[model_id] = None
        towers_names, towers_modules = _detect_main_towers(current_model)
        towers_names = [n + "." for n in towers_names]
        if self.verboseLevel >= 2 and len(towers_names) > 0:
            print(f"Potential iterative blocks found in model '{model_id}':{towers_names}")
        # compile the main iterative module stacks ("towers")
        if compileAllModels or model_id in modelsToCompile:
            # torch.compiler.reset()
            if self.verboseLevel >= 1:
                print(f"Pytorch compilation of model '{model_id}' is scheduled.")
            for tower in towers_modules:
                for submodel in tower:
                    submodel.forward = torch.compile(submodel.forward, backend="inductor", mode="default")  # other options: fullgraph=True, mode="reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"
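            # The "towers" returned by _detect_main_towers are the model's main repeated block stacks;
            # for a typical diffusion transformer this might look like `transformer_blocks.0` ...
            # `transformer_blocks.39` (names here are purely illustrative).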
        for submodule_name, submodule in current_model.named_modules():
            # create a fake 'accelerate' hook so that the _execution_device property always returns "cuda"
            # (it is queried in many pipelines even when offloading is not properly implemented)
            if not hasattr(submodule, "_hf_hook"):
                setattr(submodule, "_hf_hook", HfHook())

            if submodule_name == '':
                continue
            newListItem = False
            if current_budget > 0:
                if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
                    if cur_blocks_prefix is None:
                        cur_blocks_prefix = submodule_name + "."
                    else:
                        # if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
                        if not submodule_name.startswith(cur_blocks_prefix):
                            cur_blocks_prefix = submodule_name + "."
                            cur_blocks_name, cur_blocks_seq = None, -1
                else:
                    if cur_blocks_prefix is not None:
                        if submodule_name.startswith(cur_blocks_prefix):
                            num = int(submodule_name[len(cur_blocks_prefix):].split(".")[0])
                            newListItem = num != cur_blocks_seq
                            if num != cur_blocks_seq and (cur_blocks_name is None or current_size > current_budget):
                                prev_blocks_name = cur_blocks_name
                                cur_blocks_name = cur_blocks_prefix + str(num)
                                # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
                                cur_blocks_seq = num
                        else:
                            cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
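                    # Walk-through (illustrative): for a ModuleList named "blocks", the submodules
                    # "blocks.0", "blocks.0.attn", "blocks.1", ... are visited in order; a new offload
                    # block is only opened when the list index changes and the accumulated size of the
                    # current block has exceeded the model budget, so several consecutive list items may
                    # end up grouped into the same block.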
            if hasattr(submodule, "forward"):
                submodule_method = getattr(submodule, "forward")
                if callable(submodule_method):
                    if len(submodule_name.split(".")) == 1:
                        self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
                    elif newListItem:
                        self.hook_load_data_if_needed(submodule, model_id, cur_blocks_name, context = submodule_name)
                    else:
                        self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name)
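                    # Dispatch summary: top-level submodules are hooked with hook_change_module, the first
                    # submodule of a newly started list item with hook_load_data_if_needed, and every other
                    # submodule with hook_check_empty_cache_needed.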
            current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)

        parameters_data = {}
        if pinAllModels or model_id in modelsToPin:
            if hasattr(current_model, "_already_pinned"):
                if self.verboseLevel >= 1:
                    print(f"Model '{model_id}' already pinned to reserved memory")
            else:
                _pin_to_memory(current_model, model_id, partialPinning = partialPinning, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel = verboseLevel)

        for p in current_model.parameters():
            parameters_data[p] = [p._data, p._scale] if isinstance(p, QTensor) else p.data

        buffers_data = {b: b.data for b in current_model.buffers()}
        parameters_data.update(buffers_data)
        self.modules_data[model_id] = parameters_data

    if self.verboseLevel >= 2:
        for n, b in self.blocks_of_modules_sizes.items():
            print(f"Size of submodel '{n}': {b/ONE_MB:.1f} MB")

    torch.set_default_device('cuda')
    torch.cuda.empty_cache()
    gc.collect()

    return self

def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.VerylowRAM_LowVRAM_Slowest, verboseLevel = 1, **overrideKwargs):
    """Apply a configuration profile that depends on your hardware:
    pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
    profile_no : number of the profile:
        HighRAM_HighVRAM_Fastest (=1): at least 48 GB of RAM and 24 GB of VRAM : the fastest, well suited for an RTX 3090 / RTX 4090
        HighRAM_LowVRAM_Fast (=2): at least 48 GB of RAM and 12 GB of VRAM : a bit slower, better suited for an RTX 3070/3080/4070/4080
            or for an RTX 3090 / RTX 4090 with large picture batches or long videos
        LowRAM_HighVRAM_Medium (=3): at least 32 GB of RAM and 24 GB of VRAM : medium speed, adapted to an RTX 3090 / RTX 4090 with limited RAM
        LowRAM_LowVRAM_Slow (=4): at least 32 GB of RAM and 12 GB of VRAM : if you have little VRAM or generate longer videos
        VerylowRAM_LowVRAM_Slowest (=5): at least 24 GB of RAM and 10 GB of VRAM : if you don't have much, it won't be fast but maybe it will work
    overrideKwargs : every parameter accepted by offload.all can be added here to override the profile choice
        For instance set quantizeTransformer = False to disable the transformer quantization that is enabled by default in every profile
    """
    _welcome()

    modules = pipe_or_dict_of_modules

    if hasattr(modules, "components"):
        modules = modules.components

    modules = {k: _remove_model_wrapper(v) for k, v in modules.items() if isinstance(v, torch.nn.Module)}
    module_names = {k: v.__module__.lower() for k, v in modules.items()}

    default_extraModelsToQuantize = []
    quantizeTransformer = True

    models_to_scan = ("text_encoder", "text_encoder_2")
    candidates_to_quantize = ("t5", "llama", "llm")
    for model_id in models_to_scan:
        if model_id not in module_names:
            continue
        name = module_names[model_id]
        for candidate in candidates_to_quantize:
            if candidate in name:
                default_extraModelsToQuantize.append(model_id)
                break

    # the transformer (video or image generator) should be kept as small as possible so that it does not
    # occupy space that could be used for the actual image data;
    # on the other hand the text encoder budget can be quite large (as long as it fits in 10 GB of VRAM) to reduce sequence offloading

    default_budgets = { "transformer": 600, "text_encoder": 3000, "text_encoder_2": 3000 }
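    # Reading of these defaults (illustrative): with a budget of 600 MB the transformer is streamed
    # through VRAM block by block, while each text encoder may keep up to about 3 GB resident, which
    # matches the reasoning in the comment above.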
    extraModelsToQuantize = None
    budgets = None
    asyncTransfers = True  # defaults so that every profile defines these; individual profiles below may override them

    if profile_no == profile_type.HighRAM_HighVRAM_Fastest:
        pinnedMemory = True
        budgets = None
        info = "You have chosen a Very Fast profile that requires at least 48 GB of RAM and 24 GB of VRAM."
    elif profile_no == profile_type.HighRAM_LowVRAM_Fast:
        pinnedMemory = True
        budgets = default_budgets
        info = "You have chosen a Fast profile that requires at least 48 GB of RAM and 12 GB of VRAM."
    elif profile_no == profile_type.LowRAM_HighVRAM_Medium:
        pinnedMemory = "transformer"
        extraModelsToQuantize = default_extraModelsToQuantize
        info = "You have chosen a Medium speed profile that requires at least 32 GB of RAM and 24 GB of VRAM."
    elif profile_no == profile_type.LowRAM_LowVRAM_Slow:
        pinnedMemory = "transformer"
        extraModelsToQuantize = default_extraModelsToQuantize
        budgets = default_budgets
        asyncTransfers = True
        info = "You have chosen the Slow profile that requires at least 32 GB of RAM and 12 GB of VRAM."
    elif profile_no == profile_type.VerylowRAM_LowVRAM_Slowest:
        pinnedMemory = False
        extraModelsToQuantize = default_extraModelsToQuantize
        budgets = default_budgets
        budgets["transformer"] = 400
        asyncTransfers = False
        info = "You have chosen the Slowest profile that requires at least 24 GB of RAM and 10 GB of VRAM."
    else:
        raise Exception("Unknown profile")

    CrLf = '\r\n'
    kwargs = { "pinnedMemory": pinnedMemory, "extraModelsToQuantize": extraModelsToQuantize, "budgets": budgets, "asyncTransfers": asyncTransfers, "quantizeTransformer": quantizeTransformer }
    if verboseLevel >= 2:
        info = info + CrLf + f"Profile '{profile_type.tostr(profile_no)}' sets the following options:"
        for k, v in kwargs.items():
            if k in overrideKwargs:
                info = info + CrLf + f"- '{k}': '{kwargs[k]}' overridden with value '{overrideKwargs[k]}'"
            else:
                info = info + CrLf + f"- '{k}': '{kwargs[k]}'"

    for k, v in overrideKwargs.items():
        kwargs[k] = v

    if info:
        print(info)

    return all(pipe_or_dict_of_modules, verboseLevel = verboseLevel, **kwargs)