mmgp 3.0.3__py3-none-any.whl → 3.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mmgp might be problematic.
- mmgp/offload.py +503 -394
- mmgp/safetensors2.py +85 -32
- {mmgp-3.0.3.dist-info → mmgp-3.1.0.dist-info}/METADATA +14 -10
- mmgp-3.1.0.dist-info/RECORD +9 -0
- {mmgp-3.0.3.dist-info → mmgp-3.1.0.dist-info}/WHEEL +1 -1
- mmgp-3.0.3.dist-info/RECORD +0 -9
- {mmgp-3.0.3.dist-info → mmgp-3.1.0.dist-info}/LICENSE.md +0 -0
- {mmgp-3.0.3.dist-info → mmgp-3.1.0.dist-info}/top_level.txt +0 -0
mmgp/offload.py
CHANGED
@@ -1,4 +1,4 @@
-# ------------------ Memory Management 3.
+# ------------------ Memory Management 3.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
 #
 # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
 # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -61,13 +61,25 @@ import sys
 import os
 import json
 import psutil
+try:
+    from diffusers.utils.peft_utils import set_weights_and_activate_adapters, get_peft_kwargs
+except:
+    set_weights_and_activate_adapters = None
+    get_peft_kwargs = None
+    pass
+try:
+    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+except:
+    inject_adapter_in_model = None
+    pass
+
 from mmgp import safetensors2
 from mmgp import profile_type
 
-from optimum.quanto import freeze,
-
+from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QTensor, quantize_module
 
 
+shared_state = {}
 
 mmm = safetensors2.mmm
 
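The hunk above makes `peft` and `diffusers` optional dependencies: the new Lora helpers only work when both import successfully. A minimal sketch (not part of the diff) of how calling code might probe for that support before relying on it; the `lora_support` flag is purely illustrative.

```python
# Illustrative only: mirror the optional-import guard used by offload.py to decide
# up front whether Lora loading will be available at runtime.
try:
    from peft import inject_adapter_in_model                 # optional dependency
    from diffusers.utils.peft_utils import get_peft_kwargs   # optional dependency
    lora_support = True
except Exception:                                            # either package missing
    lora_support = False

print("Lora support available:", lora_support)
```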
@@ -127,6 +139,9 @@ def move_tensors(obj, device):
         return _list
     else:
         raise TypeError("Tensor or list / dict of tensors expected")
+def _get_module_name(v):
+    return v.__module__.lower()
+
 
 def _compute_verbose_level(level):
     if level <0:
@@ -139,33 +154,75 @@ def _get_max_reservable_memory(perc_reserved_mem_max):
         perc_reserved_mem_max = 0.40 if os.name == 'nt' else 0.5
     return perc_reserved_mem_max * physical_memory
 
-def _detect_main_towers(model, verboseLevel=1):
+def _detect_main_towers(model, min_floors = 5, verboseLevel=1):
     cur_blocks_prefix = None
     towers_modules= []
     towers_names= []
 
+    floors_modules= []
+    tower_name = None
+
+
     for submodule_name, submodule in model.named_modules():
+
         if submodule_name=='':
             continue
 
-        if
-
-
-
-
-
-
-
-
+        if cur_blocks_prefix != None:
+            if submodule_name.startswith(cur_blocks_prefix):
+                depth_prefix = cur_blocks_prefix.split(".")
+                depth_name = submodule_name.split(".")
+                level = depth_name[len(depth_prefix)-1]
+                pre , num = _extract_num_from_str(level)
+
+                if num != cur_blocks_seq:
+                    floors_modules.append(submodule)
 
-
-
-
+                cur_blocks_seq = num
+            else:
+                if len(floors_modules) >= min_floors:
+                    towers_modules += floors_modules
+                    towers_names.append(tower_name)
+                tower_name = None
+                floors_modules= []
+                cur_blocks_prefix, cur_blocks_seq = None, -1
+
+        if cur_blocks_prefix == None:
+            pre , num = _extract_num_from_str(submodule_name)
+            if isinstance(submodule, (torch.nn.ModuleList)):
+                cur_blocks_prefix, cur_blocks_seq = pre + ".", -1
+                tower_name = submodule_name + ".*"
+            elif num >=0:
+                cur_blocks_prefix, cur_blocks_seq = pre, num
+                tower_name = submodule_name[ :-1] + "*"
+            floors_modules.append(submodule)
+
+    if len(floors_modules) >= min_floors:
+        towers_modules += floors_modules
+        towers_names.append(tower_name)
+
+    # for submodule_name, submodule in model.named_modules():
+    #     if submodule_name=='':
+    #         continue
+
+    #     if isinstance(submodule, torch.nn.ModuleList):
+    #         newList =False
+    #         if cur_blocks_prefix == None:
+    #             cur_blocks_prefix = submodule_name + "."
+    #             newList = True
+    #         else:
+    #             if not submodule_name.startswith(cur_blocks_prefix):
+    #                 cur_blocks_prefix = submodule_name + "."
+    #                 newList = True
+
+    #         if newList and len(submodule)>=5:
+    #             towers_names.append(submodule_name)
+    #             towers_modules.append(submodule)
 
-
-
-
-
+    #     else:
+    #         if cur_blocks_prefix is not None:
+    #             if not submodule_name.startswith(cur_blocks_prefix):
+    #                 cur_blocks_prefix = None
 
     return towers_names, towers_modules
 
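For context, a small sketch (not from the package) of the module layout the reworked `_detect_main_towers` targets: a `ModuleList` with at least `min_floors` numbered children ("floors") is reported as a tower such as `blocks.*`, while non-repeated submodules are ignored.

```python
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(16, 16)          # single module: not a tower
        self.blocks = torch.nn.ModuleList(            # 6 floors >= min_floors (5)
            [torch.nn.Linear(16, 16) for _ in range(6)]
        )

model = ToyModel()
# named_modules() yields 'embed', 'blocks', 'blocks.0' ... 'blocks.5'; the detector groups
# the numbered children under the 'blocks.' prefix and reports them as one tower ('blocks.*').
print([name for name, _ in model.named_modules() if name])
```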
@@ -179,7 +236,7 @@ def _get_model(model_path):
     _path = Path(model_path).parts
     _filename = _path[-1]
     _path = _path[:-1]
-    if len(_path)
+    if len(_path)<=1:
         raise("file not found")
     else:
         from huggingface_hub import hf_hub_download #snapshot_download,
@@ -263,7 +320,13 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
     # print(f"num params to pin {model_id}: {len(params_list)}")
     for p in params_list:
         if isinstance(p, QTensor):
-
+            if p._qtype == qint4:
+                if hasattr(p,"_scale_shift"):
+                    length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale_shift) * p._scale_shift.element_size()
+                else:
+                    length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale) * p._scale.element_size() + torch.numel(p._shift) * p._shift.element_size()
+            else:
+                length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
         else:
             length = torch.numel(p.data) * p.data.element_size()
 
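A hedged arithmetic check of the bookkeeping above for the non-qint4 branch: a quanto-quantized weight is accounted as its int8 payload plus its scale tensor. The per-output-channel bfloat16 scale layout used below is an assumption for illustration, not taken from the diff.

```python
import torch

out_features, in_features = 4096, 4096
data  = torch.empty(out_features, in_features, dtype=torch.int8)   # quantized payload
scale = torch.empty(out_features, 1, dtype=torch.bfloat16)         # assumed scale layout
length = data.numel() * data.element_size() + scale.numel() * scale.element_size()
print(f"{length / 2**20:.2f} MiB to pin for this layer")            # ~16 MiB of data + ~8 KiB of scales
```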
@@ -306,10 +369,22 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
         if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
             current_big_tensor = big_tensors[big_tensor_no]
             if isinstance(p, QTensor):
-
-
-
-
+                if p._qtype == qint4:
+                    length1 = torch.numel(p._data._data) * p._data._data.element_size()
+                    p._data._data = _move_to_pinned_tensor(p._data._data, current_big_tensor, offset, length1)
+                    if hasattr(p,"_scale_shift"):
+                        length2 = torch.numel(p._scale_shift) * p._scale_shift.element_size()
+                        p._scale_shift = _move_to_pinned_tensor(p._scale_shift, current_big_tensor, offset + length1, length2)
+                    else:
+                        length2 = torch.numel(p._scale) * p._scale.element_size()
+                        p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
+                        length3 = torch.numel(p._shift) * p._shift.element_size()
+                        p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
+                else:
+                    length1 = torch.numel(p._data) * p._data.element_size()
+                    p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
+                    length2 = torch.numel(p._scale) * p._scale.element_size()
+                    p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
             else:
                 length = torch.numel(p.data) * p.data.element_size()
                 p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
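The qint4 branch above relocates each component (`_data._data`, `_scale`/`_shift` or `_scale_shift`) into a shared preallocated buffer. A rough illustration of that "big pinned tensor" idea with plain PyTorch (requires a CUDA-enabled build; sizes and names below are made up):

```python
import torch

big = torch.empty(1 << 20, dtype=torch.uint8, pin_memory=True)   # one page-locked block
w = torch.randn(1000)                                            # a small weight to relocate
nbytes = w.numel() * w.element_size()
pinned_view = big[:nbytes].view(torch.float32).view(w.shape)     # carve a slice out of the block
pinned_view.copy_(w)                                             # weight now lives in pinned RAM
# later: pinned_view.to("cuda", non_blocking=True) can overlap the copy with compute
assert pinned_view.is_pinned()
```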
@@ -336,100 +411,16 @@ def _welcome():
     if welcome_displayed:
         return
     welcome_displayed = True
-    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.
-
-
-# def _pin_to_memory_sd(model, sd, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
-#     if verboseLevel>=1 :
-#         if partialPinning:
-#             print(f"Partial pinning to reserved RAM of data of file '{model_id}' while loading it")
-#         else:
-#             print(f"Pinning data to reserved RAM of file '{model_id}' while loading it")
-
-#     max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
-#     if partialPinning:
-#         towers_names, _ = _detect_main_towers(model)
-#         towers_names = [n +"." for n in towers_names]
-
-#     BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
-#     current_big_tensor_size = 0
-#     big_tensor_no = 0
-#     big_tensors_sizes = []
-#     tensor_map_indexes = []
-#     total_tensor_bytes = 0
-
-#     for k,t in sd.items():
-#         include = True
-#         # if isinstance(p, QTensor):
-#         #     length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
-#         # else:
-#         #     length = torch.numel(p.data) * p.data.element_size()
-#         length = torch.numel(t) * t.data.element_size()
-
-#         if partialPinning:
-#             include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
-
-#         if include:
-#             if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
-#                 big_tensors_sizes.append(current_big_tensor_size)
-#                 current_big_tensor_size = 0
-#                 big_tensor_no += 1
-#             tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
-#             current_big_tensor_size += length
-#         else:
-#             tensor_map_indexes.append((-1, 0, 0 ))
-#         total_tensor_bytes += length
-
-#     big_tensors_sizes.append(current_big_tensor_size)
-
-#     big_tensors = []
-#     last_big_tensor = 0
-#     total = 0
-
+    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
 
-
-
-
-
-
-
-
-#         last_big_tensor += 1
-#         total += size
-
-
-#     tensor_no = 0
-#     for k,t in sd.items():
-#         big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
-#         if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
-#             current_big_tensor = big_tensors[big_tensor_no]
-#             # if isinstance(p, QTensor):
-#             #     length1 = torch.numel(p._data) * p._data.element_size()
-#             #     p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
-#             #     length2 = torch.numel(p._scale) * p._scale.element_size()
-#             #     p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
-#             # else:
-#             #     length = torch.numel(p.data) * p.data.element_size()
-#             #     p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
-#             length = torch.numel(t) * t.data.element_size()
-#             t = _move_to_pinned_tensor(t, current_big_tensor, offset, length)
-#             sd[k] = t
-#         tensor_no += 1
-
-#     global total_pinned_bytes
-#     total_pinned_bytes += total
-
-#     if verboseLevel >=1:
-#         if total_tensor_bytes == total:
-#             print(f"The whole model was pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
-#         else:
-#             print(f"{total/ONE_MB:.2f} MB were pinned to reserved RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
-
-#     model._already_pinned = True
-
-
-#     return
+def _extract_num_from_str(num_in_str):
+    for i in range(len(num_in_str)):
+        if not num_in_str[-i-1:].isnumeric():
+            if i == 0:
+                return num_in_str, -1
+            else:
+                return num_in_str[: -i], int(num_in_str[-i:])
+    return "", int(num_in_str)
 
 def _quantize_dirty_hack(model):
     # dirty hack: add a hook on state_dict() to return a fake non quantized state_dict if called by Lora Diffusers initialization functions
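The `_extract_num_from_str` helper added above is small enough to exercise directly; it is reproduced verbatim here with a few calls (the inputs are illustrative) showing how it splits a trailing block index off a submodule name so that floors can be grouped into towers.

```python
def _extract_num_from_str(num_in_str):
    for i in range(len(num_in_str)):
        if not num_in_str[-i-1:].isnumeric():
            if i == 0:
                return num_in_str, -1
            else:
                return num_in_str[: -i], int(num_in_str[-i:])
    return "", int(num_in_str)

print(_extract_num_from_str("blocks.12"))   # ('blocks.', 12)
print(_extract_num_from_str("norm_out"))    # ('norm_out', -1) -> not an indexed floor
print(_extract_num_from_str("7"))           # ('', 7)
```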
@@ -536,10 +527,14 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
     prev_blocks_prefix = None
 
     if hasattr(model_to_quantize, "_quanto_map"):
+        for k, entry in model_to_quantize._quanto_map.items():
+            weights = entry["weights"]
+            print(f"Model '{model_id}' is already quantized to format '{weights}'")
+            return False
         print(f"Model '{model_id}' is already quantized")
         return False
-
-    print(f"Quantization of model '{model_id}' started")
+
+    print(f"Quantization of model '{model_id}' started to format '{weights}'")
 
     for submodule_name, submodule in model_to_quantize.named_modules():
         if isinstance(submodule, QModuleMixin):
@@ -594,18 +589,18 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
     if verboseLevel >=2:
         print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB oF {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
     if perc_excluded >= 0.10:
-        print(f"Too many
+        print(f"Too many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
         exclude_list = None
 
 
     #quantize(model_to_quantize,weights, exclude= exclude_list)
-
+
     for name, m in model_to_quantize.named_modules():
         if exclude_list is None or not any( name == module_name for module_name in exclude_list):
             _quantize_submodule(model_to_quantize, name, m, weights=weights, activations=None, optimizer=None)
 
-    # force read non quantized parameters so that their lazy tensors and corresponding mmap are released
-    # otherwise we may end up
+    # force to read non quantized parameters so that their lazy tensors and corresponding mmap are released
+    # otherwise we may end up keeping in memory both the quantized and the non quantize model
     for m in model_to_quantize.modules():
         # do not read quantized weights (detected them directly or behind an adapter)
         if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
@@ -620,18 +615,271 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
             b.data = b.data + 0
 
 
+
     freeze(model_to_quantize)
     torch.cuda.empty_cache()
     gc.collect()
     quantization_map = _quantization_map(model_to_quantize)
     model_to_quantize._quanto_map = quantization_map
 
+    if hasattr(model_to_quantize, "_already_pinned"):
+        delattr(model_to_quantize, "_already_pinned")
+
     _quantize_dirty_hack(model_to_quantize)
 
     print(f"Quantization of model '{model_id}' done")
 
     return True
 
+def load_loras_into_model(model, lora_path, lora_multi = None, verboseLevel = -1):
+    verboseLevel = _compute_verbose_level(verboseLevel)
+
+    if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
+        raise Exception("Unable to load Lora, missing 'peft' and / or 'diffusers' modules")
+
+    if not isinstance(lora_path, list):
+        lora_path = [lora_path]
+
+    if lora_multi is None:
+        lora_multi = [1. for _ in lora_path]
+
+    for i, path in enumerate(lora_path):
+        adapter_name = str(i)
+
+        state_dict = safetensors2.torch_load_file(path)
+
+        keys = list(state_dict.keys())
+        if len(keys) == 0:
+            raise Exception(f"Empty Lora '{path}'")
+
+
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.pop(k)
+                if not ( (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                )):
+                    network_alphas[k] = torch.tensor( float(alpha_value.item() ) )
+
+        pos = keys[0].find(".")
+        prefix = keys[0][0:pos]
+        if not any( prefix.startswith(some_prefix) for some_prefix in ["diffusion_model", "transformer"]):
+            msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
+            raise Exception(msg)
+
+        transformer = model
+
+        transformer_keys = [k for k in keys if k.startswith(prefix)]
+        state_dict = {
+            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        }
+
+        sd_keys = state_dict.keys()
+        if len(sd_keys) == 0:
+            print(f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format.")
+            return
+
+        # is_correct_format = all("lora" in key for key in state_dict.keys())
+
+
+
+
+        # check with first key if is not in peft format
+        # first_key = next(iter(state_dict.keys()))
+        # if "lora_A" not in first_key:
+        #     state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+        if adapter_name in getattr(transformer, "peft_config", {}):
+            raise ValueError(
+                f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+            )
+
+        rank = {}
+        for key, val in state_dict.items():
+            if "lora_B" in key:
+                rank[key] = val.shape[1]
+
+        if network_alphas is not None and len(network_alphas) >= 1:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+
+        lora_config = LoraConfig(**lora_config_kwargs)
+        peft_kwargs = {}
+        peft_kwargs["low_cpu_mem_usage"] = True
+        inject_adapter_in_model(lora_config, model, adapter_name=adapter_name, **peft_kwargs)
+
+        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
+
+        warn_msg = ""
+        if incompatible_keys is not None:
+            # Check only for unexpected keys.
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                pass
+        if verboseLevel >=1:
+            print(f"Lora '{path}' was loaded in model '{_get_module_name(model)}'")
+    set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)
+
+
+def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
+    """
+    quick version of .LoadfromPretrained of the transformers library
+    used to build a model and load the corresponding weights (quantized or not)
+    """
+
+
+    import os.path
+    from accelerate import init_empty_weights
+
+    if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
+        raise Exception("full model path to file expected")
+
+    model_path = _get_model(model_path)
+    verboseLevel = _compute_verbose_level(verboseLevel)
+
+    with safetensors2.safe_open(model_path) as f:
+        metadata = f.metadata()
+
+    if metadata is None:
+        transformer_config = None
+    else:
+        transformer_config = metadata.get("config", None)
+
+    if transformer_config == None:
+        config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
+
+        if not os.path.isfile(config_fullpath):
+            raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
+
+        with open(config_fullpath, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        transformer_config= json.loads(text)
+
+
+    if "architectures" in transformer_config:
+        architectures = transformer_config["architectures"]
+        class_name = architectures[0]
+
+        module = __import__("transformers")
+        map = { "T5WithLMHeadModel" : "T5EncoderModel"}
+        class_name = map.get(class_name, class_name)
+        transfomer_class = getattr(module, class_name)
+        from transformers import AutoConfig
+
+        import tempfile
+        with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
+            fp.write(json.dumps(transformer_config))
+            fp.close()
+        config_obj = AutoConfig.from_pretrained(fp.name)
+        os.remove(fp.name)
+
+        #needed to keep inits of non persistent buffers
+        with init_empty_weights():
+            model = transfomer_class(config_obj)
+
+        model = model.base_model
+
+    elif "_class_name" in transformer_config:
+        class_name = transformer_config["_class_name"]
+
+        module = __import__("diffusers")
+        transfomer_class = getattr(module, class_name)
+
+        with init_empty_weights():
+            model = transfomer_class.from_config(transformer_config)
+
+
+    torch.set_default_device('cpu')
+
+    model._config = transformer_config
+
+    load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
+
+    return model
+
+
+
+def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
+    """
+    Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
+    """
+
+    file_path = _get_model(file_path)
+    verboseLevel = _compute_verbose_level(verboseLevel)
+
+    model = _remove_model_wrapper(model)
+
+    # if pinToMemory and do_quantize:
+    #     raise Exception("Pinning and Quantization can not be used at the same time")
+
+    if not (".safetensors" in file_path or ".sft" in file_path):
+        if pinToMemory:
+            raise Exception("Pinning to memory while loading only supported for safe tensors files")
+        state_dict = torch.load(file_path, weights_only=True)
+        if "module" in state_dict:
+            state_dict = state_dict["module"]
+    else:
+        state_dict, metadata = _safetensors_load_file(file_path)
+
+        if metadata is None:
+            quantization_map = None
+        else:
+            quantization_map = metadata.get("quantization_map", None)
+            config = metadata.get("config", None)
+            if config is not None:
+                model._config = config
+
+
+
+        if quantization_map is None:
+            pos = str.rfind(file_path, ".")
+            if pos > 0:
+                quantization_map_path = file_path[:pos]
+                quantization_map_path += "_map.json"
+
+                if os.path.isfile(quantization_map_path):
+                    with open(quantization_map_path, 'r') as f:
+                        quantization_map = json.load(f)
+
+
+
+        if quantization_map is None :
+            if "quanto" in file_path and not do_quantize:
+                print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
+        else:
+            _requantize(model, state_dict, quantization_map)
+
+    missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
+    # if len(missing_keys) > 0:
+    #     sd_crap = { k : None for k in missing_keys}
+    #     missing_keys , unexpected_keys = model.load_state_dict(sd_crap, strict =False, assign = True )
+    del state_dict
+
+    for k,p in model.named_parameters():
+        if p.is_meta:
+            txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since parameter '{k}' has no data"
+            raise Exception(txt)
+    for k,b in model.named_buffers():
+        if b.is_meta:
+            txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since buffer '{k}' has no data"
+            raise Exception(txt)
+
+    if do_quantize:
+        if quantization_map is None:
+            if _quantize(model, quantizationType, verboseLevel=verboseLevel, model_id=file_path):
+                quantization_map = model._quanto_map
+        else:
+            if verboseLevel >=1:
+                print("Model already quantized")
+
+    if pinToMemory:
+        _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
+
+    return
+
 def get_model_name(model):
     return model.name
 
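A hedged usage sketch of the two new public entry points added in this hunk, `fast_load_transformers_model` and `load_loras_into_model`; the file paths and multipliers below are placeholders, not taken from the diff.

```python
from mmgp import offload
from optimum.quanto import qint8

model = offload.fast_load_transformers_model(
    "ckpts/transformer.safetensors",        # hypothetical safetensors file
    do_quantize=True, quantizationType=qint8,
    pinToMemory=True, partialPinning=False,
)

offload.load_loras_into_model(
    model,
    ["loras/style.safetensors", "loras/detail.safetensors"],   # hypothetical Lora files
    lora_multi=[1.0, 0.7],                                     # one multiplier per Lora
)
```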
@@ -663,6 +911,7 @@ class offload:
         self.async_transfers = False
         global last_offload_obj
         last_offload_obj = self
+
 
     def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
 
@@ -684,15 +933,25 @@ class offload:
 
         for k,p in submodule.named_parameters(recurse=False):
             if isinstance(p, QTensor):
-                blocks_params.append( (submodule, k, p
-
-
+                blocks_params.append( (submodule, k, p ) )
+
+                if p._qtype == qint4:
+                    if hasattr(p,"_scale_shift"):
+                        blocks_params_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
+                        blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
+                    else:
+                        blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
+                        blocks_params_size += torch.numel(p._shift) * p._shift.element_size()
+                        blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
+                else:
+                    blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
+                    blocks_params_size += torch.numel(p._data) * p._data.element_size()
             else:
-                blocks_params.append( (submodule, k, p
-                blocks_params_size += p.data.
+                blocks_params.append( (submodule, k, p ) )
+                blocks_params_size += torch.numel(p.data) * p.data.element_size()
 
         for k, p in submodule.named_buffers(recurse=False):
-            blocks_params.append( (submodule, k, p
+            blocks_params.append( (submodule, k, p) )
             blocks_params_size += p.data.nbytes
 
 
@@ -710,34 +969,28 @@ class offload:
             return False
         return True
 
-
+    @torch.compiler.disable()
+    def gpu_load_blocks(self, model_id, blocks_name):
         # cl = clock.start()
 
         if blocks_name != None:
            self.loaded_blocks[model_id] = blocks_name
 
        entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
-
-        def cpu_to_gpu(stream_to_use, blocks_params
+
+        def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
            with torch.cuda.stream(stream_to_use):
                for param in blocks_params:
-                    parent_module, n,
-
-
-
-
-
-
-                    else:
-
-
-                    if record_for_stream != None:
-                        if isinstance(p, QTensor):
-                            q._data.record_stream(record_for_stream)
-                            q._scale.record_stream(record_for_stream)
-                        else:
-                            p.data.record_stream(record_for_stream)
+                    parent_module, n, p = param
+                    q = p.to("cuda", non_blocking=True)
+                    q = torch.nn.Parameter(q , requires_grad=False)
+                    setattr(parent_module, n , q)
+                    # if record_for_stream != None:
+                    #     if isinstance(p, QTensor):
+                    #         q._data.record_stream(record_for_stream)
+                    #         q._scale.record_stream(record_for_stream)
+                    #     else:
+                    #         p.data.record_stream(record_for_stream)
 
 
        if self.verboseLevel >=2:
@@ -762,7 +1015,7 @@ class offload:
         # cl.stop()
         # print(f"load time: {cl.format_time_gap()}")
 
-
+    @torch.compiler.disable()
     def gpu_unload_blocks(self, model_id, blocks_name):
         # cl = clock.start()
         if blocks_name != None:
@@ -776,23 +1029,14 @@ class offload:
             print(f"Unloading model {blocks_name} ({model_name}) from GPU")
 
         blocks_params = self.blocks_of_modules[blocks_name]
-
         for param in blocks_params:
-            parent_module, n,
-
-
-                # need to change the parameter directly from the module as it can't be swapped in place due to a memory leak in the pytorch compiler
-                q = WeightQBytesTensor.create(p.qtype, p.axis, p.size(), p.stride(), data, scale, activation_qtype=p.activation_qtype, requires_grad=p.requires_grad )
-                q = torch.nn.Parameter(q , requires_grad=False)
-                setattr(parent_module, n , q)
-                del p
-            else:
-                p.data = data
-
+            parent_module, n, p = param
+            q = torch.nn.Parameter(p , requires_grad=False)
+            setattr(parent_module, n , q)
         # cl.stop()
         # print(f"unload time: {cl.format_time_gap()}")
 
-
+    # @torch.compiler.disable()
     def gpu_load(self, model_id):
         model = self.models[model_id]
         self.active_models.append(model)
@@ -824,8 +1068,8 @@ class offload:
             if torch.is_tensor(arg):
                 if arg.dtype == torch.float32:
                     arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
-
-                    arg = arg.cuda(non_blocking=True)
+                elif not arg.is_cuda:
+                    arg = arg.cuda(non_blocking=True)
             new_args.append(arg)
 
         for k in kwargs:
@@ -833,7 +1077,7 @@ class offload:
             if torch.is_tensor(arg):
                 if arg.dtype == torch.float32:
                     arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
-
+                elif not arg.is_cuda:
                     arg = arg.cuda(non_blocking=True)
             new_kwargs[k]= arg
 
@@ -874,10 +1118,10 @@ class offload:
 
         return False
 
-    def
+    def hook_preload_blocks_for_compilation(self, target_module, model_id,blocks_name, context):
 
-        @torch.compiler.disable()
-        def
+        # @torch.compiler.disable()
+        def preload_blocks_for_compile(module, *args, **kwargs):
             some_context = context #for debugging
             if blocks_name == None:
                 if self.ready_to_check_mem():
@@ -891,12 +1135,17 @@ class offload:
                     self.empty_cache_if_needed()
             self.loaded_blocks[model_id] = blocks_name
             self.gpu_load_blocks(model_id, blocks_name)
-
-
+        # need to be registered before the forward not to be break the efficiency of the compilation chain
+        # it should be at the top of the compilation as this type of hook in the middle of a chain seems to break memory performance
+        target_module.register_forward_pre_hook(preload_blocks_for_compile)
 
 
    def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
+        qint4quantization = isinstance(target_module, QModuleMixin) and target_module.weight!= None and target_module.weight.qtype == qint4
+        if qint4quantization:
+            pass
+
        def check_empty_cuda_cache(module, *args, **kwargs):
            # if self.ready_to_check_mem():
            #     self.empty_cache_if_needed()
@@ -912,6 +1161,8 @@ class offload:
                    self.empty_cache_if_needed()
                self.loaded_blocks[model_id] = blocks_name
                self.gpu_load_blocks(model_id, blocks_name)
+            if qint4quantization:
+                args, kwargs = self.move_args_to_gpu(*args, **kwargs)
 
            return previous_method(*args, **kwargs)
 
@@ -959,177 +1210,18 @@ class offload:
            print(f"Hooked in model '{model_id}' ({model_name})")
 
 
-
-    # def unhook_module(module: torch.nn.Module):
-    #     if not hasattr(module,"_mm_id"):
-    #         return
-
-    #     delattr(module, "_mm_id")
-
-    # def unhook_all(parent_module: torch.nn.Module):
-    #     for module in parent_module.components.items():
-    #         self.unhook_module(module)
-
-def fast_load_transformers_model(model_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
-    """
-    quick version of .LoadfromPretrained of the transformers library
-    used to build a model and load the corresponding weights (quantized or not)
-    """
-
-
-    import os.path
-    from accelerate import init_empty_weights
-
-    if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
-        raise Exception("full model path to file expected")
-
-    model_path = _get_model(model_path)
-    verboseLevel = _compute_verbose_level(verboseLevel)
-
-    with safetensors2.safe_open(model_path) as f:
-        metadata = f.metadata()
-
-    if metadata is None:
-        transformer_config = None
-    else:
-        transformer_config = metadata.get("config", None)
-
-    if transformer_config == None:
-        config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
-
-        if not os.path.isfile(config_fullpath):
-            raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
-
-        with open(config_fullpath, "r", encoding="utf-8") as reader:
-            text = reader.read()
-        transformer_config= json.loads(text)
-
-
-    if "architectures" in transformer_config:
-        architectures = transformer_config["architectures"]
-        class_name = architectures[0]
-
-        module = __import__("transformers")
-        transfomer_class = getattr(module, class_name)
-        from transformers import AutoConfig
-
-        import tempfile
-        with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
-            fp.write(json.dumps(transformer_config))
-            fp.close()
-        config_obj = AutoConfig.from_pretrained(fp.name)
-        os.remove(fp.name)
-
-        #needed to keep inits of non persistent buffers
-        with init_empty_weights():
-            model = transfomer_class(config_obj)
-
-        model = model.base_model
-
-    elif "_class_name" in transformer_config:
-        class_name = transformer_config["_class_name"]
-
-        module = __import__("diffusers")
-        transfomer_class = getattr(module, class_name)
-
-        with init_empty_weights():
-            model = transfomer_class.from_config(transformer_config)
-
-
-    torch.set_default_device('cpu')
-
-    model._config = transformer_config
-
-    load_model_data(model,model_path, do_quantize = do_quantize, quantization_type = quantization_type, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
-
-    return model
-
-
-
-def load_model_data(model, file_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
-    """
-    Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
-    """
-
-    file_path = _get_model(file_path)
-    verboseLevel = _compute_verbose_level(verboseLevel)
-
-    model = _remove_model_wrapper(model)
-
-    # if pinToMemory and do_quantize:
-    #     raise Exception("Pinning and Quantization can not be used at the same time")
-
-    if not (".safetensors" in file_path or ".sft" in file_path):
-        if pinToMemory:
-            raise Exception("Pinning to memory while loading only supported for safe tensors files")
-        state_dict = torch.load(file_path, weights_only=True)
-        if "module" in state_dict:
-            state_dict = state_dict["module"]
-    else:
-        state_dict, metadata = _safetensors_load_file(file_path)
-
-
-    # if pinToMemory:
-    #     _pin_to_memory_sd(model,state_dict, file_path, partialPinning = partialPinning, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel = verboseLevel)
-
-    # with safetensors2.safe_open(file_path) as f:
-    #     metadata = f.metadata()
-
-
-    if metadata is None:
-        quantization_map = None
-    else:
-        quantization_map = metadata.get("quantization_map", None)
-        config = metadata.get("config", None)
-        if config is not None:
-            model._config = config
-
-
-
-    if quantization_map is None:
-        pos = str.rfind(file_path, ".")
-        if pos > 0:
-            quantization_map_path = file_path[:pos]
-            quantization_map_path += "_map.json"
-
-            if os.path.isfile(quantization_map_path):
-                with open(quantization_map_path, 'r') as f:
-                    quantization_map = json.load(f)
-
-
-
-    if quantization_map is None :
-        if "quanto" in file_path and not do_quantize:
-            print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
-    else:
-        _requantize(model, state_dict, quantization_map)
-
-    missing_keys , unexpected_keys = model.load_state_dict(state_dict, strict = quantization_map is None, assign = True )
-    del state_dict
-
-    if do_quantize:
-        if quantization_map is None:
-            if _quantize(model, quantization_type, verboseLevel=verboseLevel, model_id=file_path):
-                quantization_map = model._quanto_map
-        else:
-            if verboseLevel >=1:
-                print("Model already quantized")
-
-    if pinToMemory:
-        _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
-
-    return
-
-def save_model(model, file_path, do_quantize = False, quantization_type = qint8, verboseLevel = -1 ):
+def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None ):
    """save the weights of a model and quantize them if requested
    These weights can be loaded again using 'load_model_data'
    """
 
    config = None
-
    verboseLevel = _compute_verbose_level(verboseLevel)
-
-
+    if config_file_path !=None:
+        with open(config_file_path, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        config= json.loads(text)
+    elif hasattr(model, "_config"):
        config = model._config
    elif hasattr(model, "config"):
        config_fullpath = None
@@ -1147,7 +1239,7 @@ def save_model(model, file_path, do_quantize = False, quantization_type = qint8,
        config= json.loads(text)
 
    if do_quantize:
-        _quantize(model, weights=
+        _quantize(model, weights=quantizationType, model_id=file_path)
 
    quantization_map = getattr(model, "_quanto_map", None)
 
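A short, hedged example of the updated `save_model` signature (the new `quantizationType` and `config_file_path` parameters); `model` is assumed to be an already-loaded `torch.nn.Module`, and the paths are placeholders.

```python
from mmgp import offload
from optimum.quanto import qint8

offload.save_model(
    model,
    "ckpts/transformer_quanto_int8.safetensors",   # hypothetical output file
    do_quantize=True,
    quantizationType=qint8,
    config_file_path="ckpts/config.json",          # hypothetical config to embed in the metadata
)
```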
@@ -1155,12 +1247,12 @@ def save_model(model, file_path, do_quantize = False, quantization_type = qint8,
        print(f"Saving file '{file_path}")
    safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
    if verboseLevel >=1:
-        print(f"File '{file_path} saved")
+        print(f"File '{file_path}' saved")
 
 
 
 
-def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = -1):
+def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = -1):
    """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
    pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
    quantizeTransformer: set True by default will quantize on the fly the video / image model
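A hedged sketch of hooking a pipeline with the new `quantizationType` argument that this hunk adds to `offload.all`; `pipe` stands for an already-built Diffusers-style pipeline and is not defined in the diff.

```python
from mmgp import offload
from optimum.quanto import qint4

offload.all(
    pipe,                        # assumption: a pipeline object exposing its modules
    pinnedMemory=True,
    quantizeTransformer=True,
    quantizationType=qint4,      # new in 3.1: choose qint8 (default), qint4, ...
    budgets=0,
    asyncTransfers=True,
)
```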
@@ -1238,13 +1330,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
    self.anyCompiledModule = compileAllModels or len(modelsToCompile)>0
    if self.anyCompiledModule:
        torch._dynamo.config.cache_size_limit = 10000
+        torch.compiler.reset()
+
        # torch._logging.set_logs(recompiles=True)
        # torch._inductor.config.realize_opcount_threshold = 100 # workaround bug "AssertionError: increase TRITON_MAX_BLOCK['X'] to 4096."
 
    max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
 
    estimatesBytesToPin = 0
-
    for model_id in models:
        current_model: torch.nn.Module = models[model_id]
        # make sure that no RAM or GPU memory is not allocated for gradiant / training
@@ -1252,19 +1345,30 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
 
        # if the model has just been quantized so there is no need to quantize it again
        if model_id in models_to_quantize:
-            _quantize(current_model, weights=
+            _quantize(current_model, weights=quantizationType, verboseLevel = self.verboseLevel, model_id=model_id)
 
        modelPinned = (pinAllModels or model_id in modelsToPin) and not hasattr(current_model,"_already_pinned")
 
-        current_model_size = 0
-
-        for p in current_model.
+        current_model_size = 0
+
+        for n, p in current_model.named_parameters():
+            p.requires_grad = False
            if isinstance(p, QTensor):
                # # fix quanto bug (seems to have been fixed)
                # if not modelPinned and p._scale.dtype == torch.float32:
                #     p._scale = p._scale.to(torch.bfloat16)
-
-
+                if p._qtype == qint4:
+                    if hasattr(p,"_scale_shift"):
+                        current_model_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
+                    else:
+                        current_model_size += torch.numel(p._scale) * p._shift.element_size() + torch.numel(p._scale) * p._shift.element_size()
+
+                    current_model_size += torch.numel(p._data._data) * p._data._data.element_size()
+
+                else:
+                    current_model_size += torch.numel(p._scale) * p._scale.element_size()
+                    current_model_size += torch.numel(p._data) * p._data.element_size()
+
            else:
                if p.data.dtype == torch.float32:
                    # convert any left overs float32 weight to bloat16 to divide by 2 the model memory footprint
@@ -1272,7 +1376,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
                current_model_size += torch.numel(p.data) * p.data.element_size()
 
        for b in current_model.buffers():
-            if b.data.dtype == torch.float32:
+            if b.data.dtype == torch.float32:
                # convert any left overs float32 weight to bloat16 to divide by 2 the model memory footprint
                b.data = b.data.to(torch.bfloat16)
            current_model_size += torch.numel(b.data) * b.data.element_size()
@@ -1298,22 +1402,21 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
    # Hook forward methods of modules
    for model_id in models:
        current_model: torch.nn.Module = models[model_id]
-        current_budget = model_budgets[model_id]
-        current_size = 0
-        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
-        self.loaded_blocks[model_id] = None
        towers_names, towers_modules = _detect_main_towers(current_model)
-        towers_names = [n +"." for n in towers_names]
        if self.verboseLevel>=2 and len(towers_names)>0:
            print(f"Potential iterative blocks found in model '{model_id}':{towers_names}")
        # compile main iterative modules stacks ("towers")
-
-
+        compilationInThisOne = compileAllModels or model_id in modelsToCompile
+        if compilationInThisOne:
            if self.verboseLevel>=1:
-
-
-
-
+                if len(towers_modules)>0:
+                    print(f"Pytorch compilation of model '{model_id}' is scheduled.")
+                else:
+                    print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
+
+            for submodel in towers_modules:
+                # for submodel in tower:
+                submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
                #dynamic=True,
 
        if pinAllModels or model_id in modelsToPin:
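The compilation strategy above only wraps the forwards of the detected tower floors in `torch.compile`, rather than compiling the whole model. A standalone sketch of the same idea (PyTorch 2.x assumed; the toy `ModuleList` stands in for `towers_modules`):

```python
import torch

blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(6)])
for blk in blocks:   # compile each repeated floor, as the hunk does for towers_modules
    blk.forward = torch.compile(blk.forward, backend="inductor", mode="default")

x = torch.randn(2, 64)
for blk in blocks:
    x = blk(x)
print(x.shape)
```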
@@ -1323,6 +1426,11 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
            else:
                _pin_to_memory(current_model, model_id, partialPinning= partialPinning, perc_reserved_mem_max=perc_reserved_mem_max, verboseLevel=verboseLevel)
 
+        current_budget = model_budgets[model_id]
+        current_size = 0
+        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
+        self.loaded_blocks[model_id] = None
+
        for submodule_name, submodule in current_model.named_modules():
            # create a fake 'accelerate' parameter so that the _execution_device property returns always "cuda"
            # (it is queried in many pipelines even if offloading is not properly implemented)
@@ -1331,44 +1439,43 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
 
            if submodule_name=='':
                continue
-
+
            if current_budget > 0:
-                if
-                    if cur_blocks_prefix
-
+                if cur_blocks_prefix != None:
+                    if submodule_name.startswith(cur_blocks_prefix):
+                        depth_prefix = cur_blocks_prefix.split(".")
+                        depth_name = submodule_name.split(".")
+                        level = depth_name[len(depth_prefix)-1]
+                        pre , num = _extract_num_from_str(level)
+                        if num != cur_blocks_seq and (cur_blocks_seq == -1 or current_size > current_budget):
+                            prev_blocks_name = cur_blocks_name
+                            cur_blocks_name = cur_blocks_prefix + str(num)
+                            # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
+                        cur_blocks_seq = num
                    else:
-
-
-
-
-
-
-
-
-
-
-
-
-                            cur_blocks_name = cur_blocks_prefix + str(num)
-                            # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
-                        cur_blocks_seq = num
-                    else:
-                        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
-
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
+
+                if cur_blocks_prefix == None:
+                    pre , num = _extract_num_from_str(submodule_name)
+                    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre + ".", None, -1
+                    elif num >=0:
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre, None, num
+                        cur_blocks_name = submodule_name
+                        # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
+
+
            if hasattr(submodule, "forward"):
                submodule_method = getattr(submodule, "forward")
                if callable(submodule_method):
                    if len(submodule_name.split("."))==1:
                        self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
-                    elif
-                        self.
+                    elif compilationInThisOne and submodule in towers_modules:
+                        self.hook_preload_blocks_for_compilation(submodule, model_id, cur_blocks_name, context = submodule_name )
                    else:
                        self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name )
 
-
-            current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
-
-
+            current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
 
 
        if self.verboseLevel >=2:
@@ -1406,7 +1513,7 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
        modules= modules.components
 
    modules = {k: _remove_model_wrapper(v) for k, v in modules.items() if isinstance(v, torch.nn.Module)}
-    module_names = {k: v
+    module_names = {k: _get_module_name(v) for k, v in modules.items() }
 
    default_extraModelsToQuantize = []
    quantizeTransformer = True
@@ -1414,11 +1521,12 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
    models_to_scan = ("text_encoder", "text_encoder_2")
    candidates_to_quantize = ("t5", "llama", "llm")
    for model_id in models_to_scan:
-
-
-
-
-
+        if model_id in module_names:
+            name = module_names[model_id]
+            for candidate in candidates_to_quantize:
+                if candidate in name:
+                    default_extraModelsToQuantize.append(model_id)
+                    break
 
 
    # transformer (video or image generator) should be as small as possible not to occupy space that could be used by actual image data
@@ -1427,6 +1535,7 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
    default_budgets = { "transformer" : 600 , "text_encoder": 3000, "text_encoder_2": 3000 }
    extraModelsToQuantize = None
    asyncTransfers = True
+    budgets = None
 
    if profile_no == profile_type.HighRAM_HighVRAM:
        pinnedMemory= True