mmgp 3.5.7__py3-none-any.whl → 3.6.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mmgp/fp8_quanto_bridge.py +645 -0
- mmgp/fp8_quanto_bridge_old.py +498 -0
- mmgp/offload.py +1038 -248
- mmgp/quant_router.py +518 -0
- mmgp/quanto_int8_cuda.py +97 -0
- mmgp/quanto_int8_inject.py +335 -0
- mmgp/safetensors2.py +57 -10
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/METADATA +2 -2
- mmgp-3.6.11.dist-info/RECORD +14 -0
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/licenses/LICENSE.md +1 -1
- mmgp-3.5.7.dist-info/RECORD +0 -9
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/WHEEL +0 -0
- {mmgp-3.5.7.dist-info → mmgp-3.6.11.dist-info}/top_level.txt +0 -0
mmgp/offload.py
CHANGED
@@ -1,4 +1,4 @@
-# ------------------ Memory Management 3.
+# ------------------ Memory Management 3.6.11 for the GPU Poor by DeepBeepMeep (mmgp)------------------
 #
 # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
 # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -60,16 +60,23 @@ import functools
 import sys
 import os
 import json
+import inspect
 import psutil
 import builtins
 from accelerate import init_empty_weights
-
+from functools import wraps
+import functools
+import types
 
 from mmgp import safetensors2
 from mmgp import profile_type
-
+from .quant_router import (
+    apply_pre_quantization,
+    cache_quantization_for_file,
+    detect_and_convert,
+    detect_safetensors_format,
+)
 from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QLinear, QTensor, quantize_module, register_qmodule
-
 # support for Embedding module quantization that is not supported by default by quanto
 @register_qmodule(torch.nn.Embedding)
 class QEmbedding(QModuleMixin, torch.nn.Embedding):
@@ -83,8 +90,36 @@ class QEmbedding(QModuleMixin, torch.nn.Embedding):
         return torch.nn.functional.embedding( input, self.qweight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse )
 
 
+
+def cudacontext(device):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            with torch.device(device):
+                return func(*args, **kwargs)
+        return wrapper
+    return decorator
+
+
 shared_state = {}
 
+def get_cache(cache_name):
+    all_cache = shared_state.get("_cache", None)
+    if all_cache is None:
+        all_cache = {}
+        shared_state["_cache"]= all_cache
+    cache = all_cache.get(cache_name, None)
+    if cache is None:
+        cache = {}
+        all_cache[cache_name] = cache
+    return cache
+
+def clear_caches():
+    all_cache = shared_state.get("_cache", None)
+    if all_cache is not None:
+        all_cache.clear()
+
+
 mmm = safetensors2.mmm
 
 default_verboseLevel = 1
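For orientation, a minimal usage sketch of the helpers introduced in the hunk above (illustrative only, not part of the diff; the decorated function and cache keys are made up):

    import torch
    from mmgp.offload import cudacontext, get_cache, clear_caches

    @cudacontext("cpu")                  # run the body with torch.device("cpu") as the default device
    def build_empty_tensor():
        return torch.empty(8)            # allocated on CPU because of the decorator

    cfg_cache = get_cache("configs")     # lazily creates shared_state["_cache"]["configs"]
    cfg_cache["last_model"] = "flux"
    clear_caches()                       # drops every named cache at once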
@@ -277,23 +312,112 @@ def _safetensors_load_file(file_path, writable_tensors = True):
 
 def _force_load_buffer(p):
     # To do : check if buffer was persistent and transfer state, or maybe swap keep already this property ?
-    q = torch.nn.Buffer(p
+    q = torch.nn.Buffer(p.clone())
     torch.utils.swap_tensors(p, q)
     del q
 
 def _force_load_parameter(p):
-    q = torch.nn.Parameter(p
+    q = torch.nn.Parameter(p.clone())
     torch.utils.swap_tensors(p, q)
     del q
 
-def
-    if
-
-
+def _unwrap_quantized_tensor(tensor):
+    if hasattr(tensor, "_data") and torch.is_tensor(tensor._data):
+        return tensor._data
+    return tensor
+
+def _qtensor_get_quantized_subtensors(self):
+    subtensors = []
+    if getattr(self, "_qtype", None) == qint4:
+        data = _unwrap_quantized_tensor(self._data)
+        subtensors.append(("data", data))
+        if hasattr(self, "_scale_shift") and self._scale_shift is not None:
+            subtensors.append(("scale_shift", self._scale_shift))
+        else:
+            if hasattr(self, "_scale") and self._scale is not None:
+                subtensors.append(("scale", self._scale))
+            if hasattr(self, "_shift") and self._shift is not None:
+                subtensors.append(("shift", self._shift))
+        return subtensors
+
+    if hasattr(self, "_data"):
+        data = _unwrap_quantized_tensor(self._data)
+        subtensors.append(("data", data))
+    if hasattr(self, "_scale") and self._scale is not None:
+        subtensors.append(("scale", self._scale))
+    return subtensors
+
+def _qtensor_set_quantized_subtensors(self, sub_tensors):
+    if isinstance(sub_tensors, dict):
+        sub_map = sub_tensors
+    else:
+        sub_map = {name: tensor for name, tensor in sub_tensors}
+
+    data = sub_map.get("data", None)
+    if data is not None:
+        if hasattr(self, "_data") and hasattr(self._data, "_data") and torch.is_tensor(self._data._data):
+            self._data._data = data
+        else:
+            self._data = data
+
+    if getattr(self, "_qtype", None) == qint4:
+        if "scale_shift" in sub_map and sub_map["scale_shift"] is not None:
+            self._scale_shift = sub_map["scale_shift"]
         else:
-
-
-
+            if "scale" in sub_map and sub_map["scale"] is not None:
+                self._scale = sub_map["scale"]
+            if "shift" in sub_map and sub_map["shift"] is not None:
+                self._shift = sub_map["shift"]
+    else:
+        if "scale" in sub_map and sub_map["scale"] is not None:
+            self._scale = sub_map["scale"]
+
+if not hasattr(QTensor, "get_quantized_subtensors"):
+    QTensor.get_quantized_subtensors = _qtensor_get_quantized_subtensors
+if not hasattr(QTensor, "set_quantized_subtensors"):
+    QTensor.set_quantized_subtensors = _qtensor_set_quantized_subtensors
+
+def _get_quantized_subtensors(p):
+    getter = getattr(p, "get_quantized_subtensors", None)
+    if getter is None:
+        return None
+    sub_tensors = getter()
+    if not sub_tensors:
+        return None
+    if isinstance(sub_tensors, dict):
+        sub_tensors = list(sub_tensors.items())
+    out = []
+    for name, tensor in sub_tensors:
+        if tensor is None:
+            continue
+        if torch.is_tensor(tensor):
+            out.append((name, tensor))
+    return out if out else None
+
+def _set_quantized_subtensors(p, sub_tensors):
+    setter = getattr(p, "set_quantized_subtensors", None)
+    if setter is None:
+        return False
+    setter(sub_tensors)
+    return True
+
+def _subtensors_nbytes(sub_tensors):
+    return sum(torch.numel(t) * t.element_size() for _, t in sub_tensors)
+
+def _subtensors_itemsize(sub_tensors, fallback):
+    for _, t in sub_tensors:
+        return t.element_size()
+    return fallback
+
+def _get_tensor_ref(p):
+    sub_tensors = _get_quantized_subtensors(p)
+    if sub_tensors:
+        for _, t in sub_tensors:
+            ref = t.data_ptr()
+            del sub_tensors
+            return ref
+    del sub_tensors
+    return p.data_ptr()
 
 
 BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
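A short sketch of how the subtensor accessors patched onto QTensor above are meant to be used; it mirrors what _pin_to_memory does later in this diff (illustrative only; `qweight` is assumed to be a quanto-quantized weight):

    # Read the raw storages behind a quantized weight, transform them, then write them back in place.
    subs = qweight.get_quantized_subtensors()             # e.g. [("data", ...), ("scale", ...)]
    pinned = {name: t.pin_memory() for name, t in subs}   # example transform: page-lock each storage
    qweight.set_quantized_subtensors(pinned)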
@@ -516,25 +640,18 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
             tied_weights_last = f"{match_name} <-> {n}"
             tied_weights[n] = match_name
         else:
-
-
-
-
-
-
-
-            else:
-                length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale) * p._scale.element_size() + torch.numel(p._shift) * p._shift.element_size()
-            else:
-                length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
-                if p._data.is_pinned():
-                    params_dict[n] = (None, False)
-                    continue
+            sub_tensors = _get_quantized_subtensors(p)
+            if sub_tensors:
+                if builtins.all(t.is_pinned() for _, t in sub_tensors):
+                    params_dict[n] = (None, False)
+                    del sub_tensors
+                    continue
+                length = _subtensors_nbytes(sub_tensors)
             else:
                 if p.data.is_pinned():
                     params_dict[n] = (None, False)
                     continue
-            length = torch.numel(p.data) * p.data.element_size()
+                length = torch.numel(p.data) * p.data.element_size()
 
         ref_cache[ref] = (n, length)
         if current_big_tensor_size + length > big_tensor_size and current_big_tensor_size !=0 :
@@ -542,8 +659,11 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
             current_big_tensor_size = 0
             big_tensor_no += 1
 
-
-
+        if sub_tensors:
+            itemsize = _subtensors_itemsize(sub_tensors, p.data.dtype.itemsize)
+            del sub_tensors
+        else:
+            itemsize = p.data.dtype.itemsize
         if current_big_tensor_size % itemsize:
             current_big_tensor_size += itemsize - current_big_tensor_size % itemsize
         tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
@@ -580,15 +700,11 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
             q_name = tied_weights.get(n,None)
             if q_name != None:
                 q , _ = params_dict[q_name]
-
-
-
-
-
-                else:
-                    p._data = q._data
-                    p._scale = q._scale
-                assert p._data.is_pinned()
+                sub_tensors = _get_quantized_subtensors(q)
+                if sub_tensors:
+                    sub_map = {name: tensor for name, tensor in sub_tensors}
+                    _set_quantized_subtensors(p, sub_map)
+                    del sub_map, sub_tensors
                 else:
                     p.data = q.data
                     assert p.data.is_pinned()
@@ -618,27 +734,21 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
         total += size
 
         current_big_tensor = big_tensors[big_tensor_no]
+
         if is_buffer :
             _force_load_buffer(p) # otherwise potential memory leak
-
-
-
-
-
-
-
-
-
-
-                length3 = torch.numel(p._shift) * p._shift.element_size()
-                p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
-            else:
-                length1 = torch.numel(p._data) * p._data.element_size()
-                p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
-                length2 = torch.numel(p._scale) * p._scale.element_size()
-                p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
+        sub_tensors = _get_quantized_subtensors(p)
+        if sub_tensors:
+            sub_offset = offset
+            new_subs = {}
+            for name, tensor in sub_tensors:
+                length = torch.numel(tensor) * tensor.element_size()
+                new_subs[name] = _move_to_pinned_tensor(tensor, current_big_tensor, sub_offset, length)
+                sub_offset += length
+            _set_quantized_subtensors(p, new_subs)
+            del new_subs, sub_tensors
         else:
-            length = torch.numel(p.data) * p.data.element_size()
+            length = torch.numel(p.data) * p.data.element_size()
             p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
 
         tensor_no += 1
@@ -666,18 +776,22 @@ def _welcome():
     if welcome_displayed:
         return
     welcome_displayed = True
-    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.
+    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.11) by DeepBeepMeep ************{ENDC}{UNBOLD}")
 
 def change_dtype(model, new_dtype, exclude_buffers = False):
     for submodule_name, submodule in model.named_modules():
         if hasattr(submodule, "_lock_dtype"):
             continue
         for n, p in submodule.named_parameters(recurse = False):
+            if isinstance(p, QTensor):
+                continue
             if p.data.dtype != new_dtype:
                 p.data = p.data.to(new_dtype)
 
         if not exclude_buffers:
             for p in submodule.buffers(recurse=False):
+                if isinstance(p, QTensor):
+                    continue
                 if p.data.dtype != new_dtype:
                     p.data = p.data.to(new_dtype)
 
@@ -751,7 +865,7 @@ def _quantize_submodule(
         setattr(module, name, None)
         del param
 
-def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict):
+def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict, default_dtype=None):
     # change dtype of current meta model parameters because 'requantize' won't update the dtype on non quantized parameters
     for k, p in model.named_parameters():
         if not k in quantization_map and k in state_dict:
@@ -770,6 +884,11 @@ def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict
         if activations == "none":
             activations = None
         _quantize_submodule(model, name, m, weights=weights, activations=activations)
+        if default_dtype is not None:
+            new_module = model.get_submodule(name)
+            setter = getattr(new_module, "set_default_dtype", None)
+            if callable(setter):
+                setter(default_dtype)
 
     model._quanto_map = quantization_map
 
@@ -803,6 +922,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
 
     cache_ref = {}
     tied_weights= {}
+    reversed_tied_weights= {}
 
     for submodule_name, submodule in model_to_quantize.named_modules():
         if isinstance(submodule, QModuleMixin):
@@ -815,7 +935,9 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
             ref = _get_tensor_ref(p)
             match = cache_ref.get(ref, None)
             if match != None:
-                tied_weights[submodule_name]= (n, ) + match
+                tied_weights[submodule_name]= (n, ) + match
+                entries = reversed_tied_weights.get( match, [])
+                reversed_tied_weights[match] = entries + [ (p, submodule_name,n)]
             else:
                 cache_ref[ref] = (submodule_name, n)
                 size += torch.numel(p.data) * sizeofhalffloat
@@ -883,6 +1005,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
     # force to read non quantized parameters so that their lazy tensors and corresponding mmap are released
     # otherwise we may end up keeping in memory both the quantized and the non quantize model
     named_modules = {n:m for n,m in model_to_quantize.named_modules()}
+
    for module_name, module in named_modules.items():
         # do not read quantized weights (detected them directly or behind an adapter)
         if isinstance(module, QModuleMixin) or hasattr(module, "base_layer") and isinstance(module.base_layer, QModuleMixin):
@@ -891,12 +1014,18 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
         else:
             tied_w = tied_weights.get(module_name, None)
             for n, p in module.named_parameters(recurse = False):
+
                 if tied_w != None and n == tied_w[0]:
                     if isinstance( named_modules[tied_w[1]], QModuleMixin) :
                         setattr(module, n, None) # release refs of tied weights if source is going to be quantized
                         # otherwise don't force load as it will be loaded in the source anyway
                 else:
                     _force_load_parameter(p)
+                    entries = reversed_tied_weights.get( (module_name, n), [])
+                    for tied_weight, tied_module_name, tied_weight_name in entries:
+                        if n == tied_weight_name:
+                            tied_weight.data = p.data
+
                 del p # del p if not it will still contain a ref to a tensor when leaving the loop
             for b in module.buffers(recurse = False):
                 _force_load_buffer(b)
@@ -927,38 +1056,340 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
 
     return True
 
-def
-
+def _as_field_tuple(value):
+    if not value:
+        return ()
+    if isinstance(value, str):
+        return (value,)
+    return tuple(value)
+
+
+def _get_split_handler(info, field, default_handlers):
+    handlers = info.get("split_handlers") or info.get("field_handlers") or {}
+    if handlers:
+        handler = handlers.get(field)
+        if handler is not None:
+            return handler
+    if default_handlers:
+        return default_handlers.get(field)
+    return None
+
+
+def _get_split_base_fields(info, split_fields):
+    base_fields = _as_field_tuple(info.get("base_fields") or info.get("base_field"))
+    if base_fields:
+        return base_fields
+    if split_fields:
+        return (next(iter(split_fields.keys())),)
+    return ()
+
+
+def _merge_share_fields(info, share_fields):
+    info_fields = _as_field_tuple(info.get("share_fields") or info.get("shared_fields"))
+    return tuple(sorted(set(info_fields).union(_as_field_tuple(share_fields))))
+
+
+def _call_split_handler(handler, *, src, dim, split_sizes, context):
+    if handler is None:
+        return None
+    try:
+        chunks = handler(src=src, dim=dim, split_sizes=split_sizes, context=context)
+    except Exception:
+        return None
+    if not isinstance(chunks, (list, tuple)) or len(chunks) != len(split_sizes):
+        return None
+    return chunks
+
+
+def _fill_sub_maps(sub_maps, name, value):
+    for sub_map in sub_maps:
+        sub_map[name] = value
+
+
+def sd_split_linear(
+    state_dict,
+    split_map,
+    split_fields=None,
+    share_fields=None,
+    verboseLevel=1,
+    split_handlers=None,
+):
+    if not split_map:
+        return state_dict
+    split_fields = split_fields or {}
+    share_fields = share_fields or ()
+    split_handlers = split_handlers or {}
+    base_fields_by_suffix = {
+        suffix: _get_split_base_fields(info or {}, split_fields)
+        for suffix, info in split_map.items()
+    }
+    def _skip(msg):
+        if verboseLevel >= 2:
+            print(f"[sd_split_linear] Skip {msg}")
+
+    bases = {}
+    for key in state_dict.keys():
+        for suffix, base_fields in base_fields_by_suffix.items():
+            for base_field in base_fields:
+                suffix_token = f"{suffix}.{base_field}"
+                if not key.endswith(suffix_token):
+                    continue
+                base = key[: -len("." + base_field)]
+                if base.endswith(suffix):
+                    bases[base] = suffix
+                    break
+
+    if not bases:
+        return state_dict
+
+    for base, suffix in bases.items():
+        info = split_map.get(suffix) or {}
+        mapped = info.get("mapped_modules") or info.get("mapped_suffixes") or info.get("mapped") or []
+        if not mapped:
+            continue
+
+        base_fields = base_fields_by_suffix.get(suffix) or _get_split_base_fields(info, split_fields)
+        size_field = info.get("size_field") or (base_fields[0] if base_fields else None)
+        size_tensor = state_dict.get(base + "." + size_field) if size_field else None
+        split_dim = info.get("split_dim", 0)
+        split_sizes = list(info.get("split_sizes") or [])
+        if not split_sizes:
+            if size_tensor is None:
+                continue
+            if size_tensor.dim() <= split_dim:
+                _skip(f"{base}: dim={size_tensor.dim()} split_dim={split_dim}")
+                continue
+            out_dim = size_tensor.size(split_dim)
+            if out_dim % len(mapped) != 0:
+                _skip(f"{base}: out_dim={out_dim} not divisible by {len(mapped)}")
+                continue
+            split_sizes = [out_dim // len(mapped)] * len(mapped)
+        elif None in split_sizes:
+            if size_tensor is None:
+                continue
+            if size_tensor.dim() <= split_dim:
+                _skip(f"{base}: dim={size_tensor.dim()} split_dim={split_dim}")
+                continue
+            known = sum(size for size in split_sizes if size is not None)
+            none_count = split_sizes.count(None)
+            remaining = size_tensor.size(split_dim) - known
+            if remaining < 0 or remaining % none_count != 0:
+                _skip(f"{base}: cannot resolve split sizes")
+                continue
+            fill = remaining // none_count
+            split_sizes = [fill if size is None else size for size in split_sizes]
+
+        total = sum(split_sizes)
+        prefix = base[: -len(suffix)]
+        target_bases = [prefix + name for name in mapped]
+        added = 0
+
+        field_tensors = {
+            field: state_dict.get(base + "." + field)
+            for field in set(split_fields.keys()).union(share_fields)
+        }
+        base_ctx = {
+            "state_dict": state_dict,
+            "base": base,
+            "suffix": suffix,
+            "split_sizes": split_sizes,
+            "total": total,
+            "mapped": mapped,
+            "target_bases": target_bases,
+            "verboseLevel": verboseLevel,
+            "split_fields": split_fields,
+            "share_fields": share_fields,
+            "field_tensors": field_tensors,
+            "size_field": size_field,
+            "size_tensor": size_tensor,
+            "split_dim": split_dim,
+            "info": info,
+        }
+        fields_iter = list(split_fields.items()) + [(field, None) for field in share_fields]
+        for field, dim in fields_iter:
+            src = field_tensors.get(field)
+            if src is None:
+                continue
+            if dim is None:
+                for target_base in target_bases:
+                    dest_key = target_base + "." + field
+                    if dest_key not in state_dict:
+                        state_dict[dest_key] = src
+                        added += 1
+                continue
+            if src.dim() <= dim:
+                _skip(f"{base}.{field}: dim={src.dim()} split_dim={dim}")
+                continue
+            if src.size(dim) != total:
+                _skip(f"{base}.{field}: size({dim})={src.size(dim)} expected={total}")
+                continue
+            handler = _get_split_handler(info, field, split_handlers)
+            chunks = _call_split_handler(
+                handler,
+                src=src,
+                dim=dim,
+                split_sizes=split_sizes,
+                context=dict(base_ctx, field=field),
+            )
+            if chunks is None:
+                chunks = torch.split(src, split_sizes, dim=dim)
+            for target_base, chunk in zip(target_bases, chunks):
+                if torch.is_tensor(chunk) and not chunk.is_contiguous():
+                    chunk = chunk.contiguous()
+                dest_key = target_base + "." + field
+                if dest_key not in state_dict:
+                    state_dict[dest_key] = chunk
+                    added += 1
+
+        if added:
+            for field in list(split_fields.keys()) + list(share_fields):
+                state_dict.pop(base + "." + field, None)
+            if verboseLevel >= 2:
+                print(f"[sd_split_linear] Split {base} -> {', '.join(mapped)}")
+
+    return state_dict
+
+
+def split_linear_modules(model, map, split_handlers=None, share_fields=None):
+    from optimum.quanto import QModuleMixin
     from accelerate import init_empty_weights
 
+    split_handlers = split_handlers or {}
+    share_fields = share_fields or ()
+
     modules_dict = { k: m for k, m in model.named_modules()}
     for module_suffix, split_info in map.items():
         mapped_modules = split_info["mapped_modules"]
         split_sizes = split_info["split_sizes"]
+        split_share_fields = _merge_share_fields(split_info, share_fields)
+        split_dims = split_info.get("split_dims") or {}
         for k, module in modules_dict.items():
             if k.endswith("." + module_suffix):
                 parent_module = modules_dict[k[:len(k)-len(module_suffix)-1]]
                 weight = module.weight
                 bias = getattr(module, "bias", None)
                 if isinstance(module, QModuleMixin):
-
-
-
-
-
-
+                    out_features_total = weight.size(0)
+                    if sum(split_sizes) != out_features_total:
+                        raise ValueError(
+                            f"Split sizes {split_sizes} do not match out_features {out_features_total} for '{k}'."
+                        )
+                    in_features = weight.size(1)
+                    sub_biases = None
+                    if bias is not None and bias.dim() > 0 and bias.size(0) == out_features_total:
+                        sub_biases = torch.split(bias, split_sizes, dim=0)
+                    else:
+                        sub_biases = [bias] * len(split_sizes)
+
+                    sub_tensors = _get_quantized_subtensors(weight)
+                    if not sub_tensors:
+                        raise ValueError(f"Unable to split quantized weight for '{k}'.")
+                    sub_maps = [dict() for _ in split_sizes]
+                    field_tensors = {name: tensor for name, tensor in sub_tensors}
+                    base_ctx = {
+                        "module": module,
+                        "module_name": k,
+                        "module_suffix": module_suffix,
+                        "mapped_modules": mapped_modules,
+                        "split_sizes": split_sizes,
+                        "out_features": out_features_total,
+                        "in_features": in_features,
+                        "field_tensors": field_tensors,
+                        "info": split_info,
+                    }
+                    for name, tensor in sub_tensors:
+                        if tensor is None or name in split_share_fields or tensor.dim() <= 1:
+                            _fill_sub_maps(sub_maps, name, tensor)
+                            continue
+                        split_dim = split_dims.get(name)
+                        if split_dim is None:
+                            if tensor.size(0) == out_features_total:
+                                split_dim = 0
+                            elif tensor.dim() > 1 and tensor.size(1) == out_features_total:
+                                split_dim = 1
+                            else:
+                                split_dim = 0
+                        handler = _get_split_handler(split_info, name, split_handlers)
+                        chunks = _call_split_handler(
+                            handler,
+                            src=tensor,
+                            dim=split_dim,
+                            split_sizes=split_sizes,
+                            context=dict(base_ctx, split_dim=split_dim),
+                        )
+                        if chunks is None:
+                            if tensor.dim() <= split_dim or tensor.size(split_dim) != out_features_total:
+                                got_size = "n/a" if tensor.dim() <= split_dim else tensor.size(split_dim)
+                                raise ValueError(
+                                    f"Cannot split '{k}' quantized tensor '{name}': "
+                                    f"expected size({split_dim})={out_features_total}, got {got_size}."
+                                )
+                            chunks = torch.split(tensor, split_sizes, dim=split_dim)
+                        for sub_map, chunk in zip(sub_maps, chunks):
+                            sub_map[name] = chunk
+
+                    create_fn = getattr(weight.__class__, "create", None)
+                    if not callable(create_fn):
+                        raise ValueError(f"Quantized weight class '{weight.__class__.__name__}' has no create()")
+                    create_sig = inspect.signature(create_fn)
+                    base_kwargs = {
+                        "qtype": getattr(weight, "qtype", None),
+                        "axis": getattr(weight, "axis", None),
+                        "stride": weight.stride(),
+                        "dtype": weight.dtype,
+                        "activation_qtype": getattr(weight, "activation_qtype", None),
+                        "requires_grad": weight.requires_grad,
+                        "group_size": getattr(weight, "_group_size", None),
+                        "device": weight.device,
+                    }
+
+                    qmodule_cls = module.__class__
+                    for sub_name, sub_size, sub_map, sub_bias in zip(
+                        mapped_modules, split_sizes, sub_maps, sub_biases
+                    ):
                         with init_empty_weights():
-                            sub_module =
-
-
-
+                            sub_module = qmodule_cls(
+                                in_features,
+                                sub_size,
+                                bias=bias is not None,
+                                device="cpu",
+                                dtype=weight.dtype,
+                                weights=module.weight_qtype,
+                                activations=module.activation_qtype,
+                                optimizer=module.optimizer,
+                                quantize_input=True,
+                            )
+                        size = list(weight.size())
+                        if size:
+                            size[0] = sub_size
+                            base_kwargs["size"] = tuple(size)
+                        create_kwargs = {}
+                        missing = []
+                        for name, param in create_sig.parameters.items():
+                            if name == "self":
+                                continue
+                            if name in sub_map:
+                                create_kwargs[name] = sub_map[name]
+                            elif name in base_kwargs and base_kwargs[name] is not None:
+                                create_kwargs[name] = base_kwargs[name]
+                            elif param.default is param.empty:
+                                missing.append(name)
+                        if missing:
+                            raise ValueError(
+                                f"Unable to rebuild quantized weight for '{k}.{sub_name}': "
+                                f"missing {missing}."
+                            )
+                        sub_weight = create_fn(**create_kwargs)
+                        sub_module.weight = torch.nn.Parameter(sub_weight, requires_grad=weight.requires_grad)
+                        if sub_bias is not None:
+                            sub_module.bias = torch.nn.Parameter(sub_bias)
                         sub_module.optimizer = module.optimizer
                         sub_module.weight_qtype = module.weight_qtype
+                        sub_module.activation_qtype = module.activation_qtype
                         setattr(parent_module, sub_name, sub_module)
-                        # del _data, _scale, _subdata, sub_d
                 else:
                     sub_data = torch.split(weight, split_sizes, dim=0)
-                    sub_bias = torch.split(bias, split_sizes, dim=0)
+                    sub_bias = torch.split(bias, split_sizes, dim=0) if bias is not None else [None] * len(split_sizes)
                     for sub_name, subdata, subbias in zip(mapped_modules, sub_data, sub_bias):
                         with init_empty_weights():
                             sub_module = torch.nn.Linear( subdata.shape[1], subdata.shape[0], bias=bias != None, device ="cpu", dtype=weight.dtype)
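A hypothetical split_map for the split_linear_modules() function added above. The suffix, target names and sizes are made up; the required keys ("mapped_modules", "split_sizes") and the rule that the sizes must sum to the fused layer's out_features come from the code itself:

    split_map = {
        "attn.qkv": {                                    # applies to every module whose name ends in ".attn.qkv"
            "mapped_modules": ["attn.q", "attn.k", "attn.v"],
            "split_sizes": [3072, 3072, 3072],           # must add up to the fused out_features
        },
    }
    split_linear_modules(model, split_map)               # replaces each fused layer with three sub-layers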
@@ -975,7 +1406,39 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
 
     loras_model_data = getattr(model, "_loras_model_data", None)
     if loras_model_data == None:
-
+        merged_loras_model_data = {}
+        merged_loras_shortcuts = {}
+        sub_loras = {}
+        for submodule_name, submodule in model.named_modules():
+            if submodule is model:
+                continue
+            sub_model_data = getattr(submodule, "_loras_model_data", None)
+            if sub_model_data:
+                submodule._lora_owner = model
+                sub_loras[submodule_name] = submodule
+                for k, v in sub_model_data.items():
+                    if k not in merged_loras_model_data:
+                        merged_loras_model_data[k] = v
+            sub_shortcuts = getattr(submodule, "_loras_model_shortcuts", None)
+            if sub_shortcuts:
+                prefix = f"{submodule_name}." if submodule_name else ""
+                for k, v in sub_shortcuts.items():
+                    merged_key = k
+                    if prefix:
+                        if k:
+                            merged_key = f"{prefix}{k}"
+                        else:
+                            merged_key = submodule_name
+                    if merged_key not in merged_loras_shortcuts:
+                        merged_loras_shortcuts[merged_key] = v
+        if merged_loras_model_data:
+            model._loras_model_data = merged_loras_model_data
+            if merged_loras_shortcuts:
+                model._loras_model_shortcuts = merged_loras_shortcuts
+            model._subloras = sub_loras
+            loras_model_data = merged_loras_model_data
+        else:
+            raise Exception(f"No Loras has been declared for this model while creating the corresponding offload object")
 
     if not check_only:
         unload_loras_from_model(model)
@@ -1027,7 +1490,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
 
     if split_linear_modules_map != None:
         new_state_dict = dict()
-        suffixes = [(".alpha", -2, False), (".lora_B.weight", -3, True), (".lora_A.weight", -3, False)]
+        suffixes = [(".alpha", -2, False), (".lora_B.weight", -3, True), (".lora_A.weight", -3, False), (".lora_up.weight", -3, True), (".lora_down.weight", -3, False),(".dora_scale", -2, False),]
         for module_name, module_data in state_dict.items():
             name_parts = module_name.split(".")
             for suffix, pos, any_split in suffixes:
@@ -1052,22 +1515,25 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
 
     if not fail:
         pos = first_key.find(".")
-        prefix = first_key[0:pos]
-        if prefix
-
-
-
-
-
+        prefix = first_key[0:pos+1]
+        if prefix in ["diffusion_model.", "transformer."]:
+            prefixes = ("diffusion_model.", "transformer.")
+            new_state_dict = {}
+            for k, v in state_dict.items():
+                for candidate in prefixes:
+                    if k.startswith(candidate):
+                        k = k[len(candidate) :]
+                        break
+                new_state_dict[k] = v
+            state_dict = new_state_dict
 
-            state_dict = { k[ len(prefix) + 1:]: v for k, v in state_dict.items() if k.startswith(prefix) }
             clean_up = True
 
     keys = list(state_dict.keys())
 
     lora_alphas = {}
     for k in keys:
-        if "alpha"
+        if k.endswith(".alpha"):
             alpha_value = state_dict.pop(k)
             if torch.is_tensor(alpha_value):
                 alpha_value = float(alpha_value.item())
@@ -1075,17 +1541,19 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
 
     invalid_keys = []
     unexpected_keys = []
-
-
-
-    diff_b = None
-    diff = None
+    new_state_dict = {}
+    for k in list(state_dict.keys()):
+        v = state_dict.pop(k)
+        lora_A = lora_B = diff_b = diff = lora_key = dora_scale = None
         if k.endswith(".diff"):
            diff = v
            module_name = k[ : -5]
         elif k.endswith(".diff_b"):
            diff_b = v
            module_name = k[ : -7]
+        elif k.endswith(".dora_scale"):
+            dora_scale = v
+            module_name = k[ : -11]
         else:
            pos = k.rfind(".lora_")
            if pos <=0:
@@ -1118,30 +1586,33 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
                 if ignore_model_variations:
                     skip = True
                 else:
-                    msg = f"Lora '{path}': Lora A dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[1]}, lora A = {v.shape[1]}). It is likely this Lora has been made for another version of this model."
+                    msg = f"Lora '{path}/{module_name}': Lora A dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[1]}, lora A = {v.shape[1]}). It is likely this Lora has been made for another version of this model."
                     error_msg = append(error_msg, msg)
                     fail = True
                     break
+            v = lora_A = lora_A.to(module.weight.dtype)
         elif lora_B != None:
            rank = lora_B.shape[1]
            if module_shape[0] != v.shape[0]:
                 if ignore_model_variations:
                     skip = True
                 else:
-                    msg = f"Lora '{path}': Lora B dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora B = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
+                    msg = f"Lora '{path}/{module_name}': Lora B dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora B = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
                     error_msg = append(error_msg, msg)
                     fail = True
                     break
+            v = lora_B = lora_B.to(module.weight.dtype)
         elif diff != None:
            lora_B = diff
            if module_shape != v.shape:
                 if ignore_model_variations:
                     skip = True
                 else:
-                    msg = f"Lora '{path}': Lora shape is not compatible with model '{_get_module_name(model)}' (model = {module_shape
+                    msg = f"Lora '{path}/{module_name}': Lora shape is not compatible with model '{_get_module_name(model)}' (model = {module_shape}, lora = {v.shape}). It is likely this Lora has been made for another version of this model."
                     error_msg = append(error_msg, msg)
                     fail = True
                     break
+            v = lora_B = lora_B.to(module.weight.dtype)
         elif diff_b != None:
            rank = diff_b.shape[0]
            if not hasattr(module, "bias"):
@@ -1160,26 +1631,42 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
                     error_msg = append(error_msg, msg)
                     fail = True
                     break
-
+            v = diff_b = diff_b.to(module.weight.dtype)
+        elif dora_scale != None:
+            rank = dora_scale.shape[1]
+            if module_shape[0] != v.shape[0]:
+                if ignore_model_variations:
+                    skip = True
+                else:
+                    msg = f"Lora '{path}': Dora Scale dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, dora scale = {v.shape[0]}). It is likely this Dora has been made for another version of this model."
+                    error_msg = append(error_msg, msg)
+                    fail = True
+                    break
+            v = dora_scale = dora_scale.to(module.weight.dtype)
         if not check_only:
+            new_state_dict[k] = v
+            v = None
             loras_module_data = loras_model_data.get(module, None)
             assert loras_module_data != None
             loras_adapter_data = loras_module_data.get(adapter_name, None)
             if loras_adapter_data == None:
-                loras_adapter_data = [None, None, None, 1.]
+                loras_adapter_data = [None, None, None, None, 1.]
+                module.any_dora = False
                 loras_module_data[adapter_name] = loras_adapter_data
             if lora_A != None:
-                loras_adapter_data[0] = lora_A
+                loras_adapter_data[0] = lora_A
             elif lora_B != None:
-                loras_adapter_data[1] = lora_B
+                loras_adapter_data[1] = lora_B
+            elif dora_scale != None:
+                loras_adapter_data[3] = dora_scale
+                loras_module_data["any_dora"] = True
             else:
-                loras_adapter_data[2] = diff_b
-                if rank != None:
-                    alpha_key = k[:-len(
+                loras_adapter_data[2] = diff_b
+            if rank != None and lora_key is not None and "lora" in lora_key:
+                alpha_key = k[:-len(lora_key)] + "alpha"
                 alpha = lora_alphas.get(alpha_key, None)
-
-
-        lora_A = lora_B = diff = diff_b = v = loras_module_data = loras_adapter_data = lora_alphas = None
+                if alpha is not None: loras_adapter_data[4] = alpha / rank
+        lora_A = lora_B = diff = diff_b = v = loras_module_data = loras_adapter_data = lora_alphas = dora_scale = None
 
     if len(invalid_keys) > 0:
         msg = f"Lora '{path}' contains non Lora keys '{trunc(invalid_keys,200)}'"
@@ -1202,7 +1689,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
     if not check_only:
         # model._loras_tied_weights[adapter_name] = tied_weights
         if pinnedLora:
-            pinned_sd_list.append(
+            pinned_sd_list.append(new_state_dict)
            pinned_names_list.append(path)
            # _pin_sd_to_memory(state_dict, path)
 
@@ -1250,6 +1737,7 @@ def sync_models_loras(model, model2):
 
 def unload_loras_from_model(model):
     if model is None: return
+    if not hasattr(model, "_loras_model_data"): return
     for _, v in model._loras_model_data.items():
         v.clear()
     for _, v in model._loras_model_shortcuts.items():
@@ -1264,9 +1752,25 @@ def unload_loras_from_model(model):
 
 
 def set_step_no_for_lora(model, step_no):
+    target = getattr(model, "_lora_owner", None)
+    while target is not None and target is not model:
+        model = target
+        target = getattr(model, "_lora_owner", None)
     model._lora_step_no = step_no
+    sub_loras = getattr(model, "_subloras", None)
+    if sub_loras:
+        submodules = sub_loras.values() if isinstance(sub_loras, dict) else sub_loras
+        for submodule in submodules:
+            if submodule is model:
+                continue
+            submodule._lora_step_no = step_no
 
 def activate_loras(model, lora_nos, lora_multi = None):
+    target = getattr(model, "_lora_owner", None)
+    while target is not None and target is not model:
+        model = target
+        target = getattr(model, "_lora_owner", None)
+
     if not isinstance(lora_nos, list):
         lora_nos = [lora_nos]
     lora_nos = [str(l) for l in lora_nos]
@@ -1281,6 +1785,15 @@ def activate_loras(model, lora_nos, lora_multi = None):
     model._lora_step_no = 0
     model._loras_active_adapters = lora_nos
     model._loras_scaling = lora_scaling_dict
+    sub_loras = getattr(model, "_subloras", None)
+    if sub_loras:
+        submodules = sub_loras.values() if isinstance(sub_loras, dict) else sub_loras
+        for submodule in submodules:
+            if submodule is model:
+                continue
+            submodule._lora_step_no = 0
+            submodule._loras_active_adapters = lora_nos
+            submodule._loras_scaling = lora_scaling_dict
 
 
 def move_loras_to_device(model, device="cpu" ):
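An illustrative calling sequence for the LoRA entry points touched above; the file names and multipliers are assumptions, only the function names and argument order come from this file:

    from mmgp import offload

    offload.load_loras_into_model(model, ["loras/style_a.safetensors", "loras/style_b.safetensors"])
    offload.activate_loras(model, [0, 1], [1.0, 0.8])    # adapter numbers plus per-lora multipliers
    offload.set_step_no_for_lora(model, 0)               # 3.6.x also forwards this to sub-modules that own loras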
@@ -1293,7 +1806,7 @@ def move_loras_to_device(model, device="cpu" ):
         if ".lora_" in k:
             m.to(device)
 
-def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None,
+def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, configKwargs ={}):
     """
     quick version of .LoadfromPretrained of the transformers library
     used to build a model and load the corresponding weights (quantized or not)
@@ -1305,7 +1818,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
         model_path = [model_path]
 
 
-    if not builtins.all(file_name.endswith(".sft") or file_name.endswith(".safetensors") or file_name.endswith(".pt") for file_name in model_path):
+    if not builtins.all(file_name.endswith(".sft") or file_name.endswith(".safetensors") or file_name.endswith(".pt") or file_name.endswith(".ckpt") for file_name in model_path):
         raise Exception("full model path to file expected")
 
     model_path = [ _get_model(file) for file in model_path]
@@ -1313,7 +1826,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
         raise Exception("Unable to find file")
 
     verboseLevel = _compute_verbose_level(verboseLevel)
-    if model_path[-1].endswith(".pt"):
+    if model_path[-1].endswith(".pt") or model_path[-1].endswith(".ckpt"):
         metadata = None
     else:
         with safetensors2.safe_open(model_path[-1], writable_tensors =writable_tensors) as f:
@@ -1376,18 +1889,18 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
     model = transfomer_class.from_config(transformer_config )
 
 
-    torch.set_default_device('cpu')
     model.eval().requires_grad_(False)
 
     model._config = transformer_config
-
-    load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, verboseLevel=verboseLevel )
+
+    load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, default_dtype = default_dtype, ignore_unused_weights = ignore_unused_weights, verboseLevel=verboseLevel )
 
     return model
 
 
 
-
+@cudacontext("cpu")
+def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, postprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, verboseLevel = -1, ignore_missing_keys = False):
     """
     Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
     """
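A sketch of a call to fast_load_transformers_model() using the keyword arguments added in this version (the checkpoint path is a placeholder; default_dtype and ignore_unused_weights are the new parameters visible in the signature above):

    import torch
    from mmgp import offload
    from optimum.quanto import qint8

    model = offload.fast_load_transformers_model(
        "ckpts/my_model_quanto_int8.safetensors",   # placeholder path
        do_quantize=False,
        quantizationType=qint8,
        pinToMemory=True,
        default_dtype=torch.bfloat16,               # new: forwarded to the quantization / conversion path
        ignore_unused_weights=False,                # new: when True, silences the unexpected-keys report
    )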
@@ -1423,8 +1936,14 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
         file_path += modules
         modules = None
 
-
-
+    normalized_paths = []
+    for file in file_path:
+        if isinstance(file, (dict, tuple)):
+            normalized_paths.append(file)
+        else:
+            normalized_paths.append(_get_model(file))
+    file_path = normalized_paths
+    if any(file is None for file in file_path):
         raise Exception("Unable to find file")
     verboseLevel = _compute_verbose_level(verboseLevel)
 
@@ -1442,18 +1961,24 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
     for no, file in enumerate(file_path):
         quantization_map = None
         tied_weights_map = None
-
+        metadata = None
+        detected_kind = None
+        if isinstance(file, tuple):
+            if len(file) != 2:
+                raise Exception("Expected a tuple of (state_dict, quantization_map)")
+            state_dict, quantization_map = file
+        elif isinstance(file, dict):
+            state_dict = file
+        elif not (".safetensors" in file or ".sft" in file):
             if pinToMemory:
                 raise Exception("Pinning to memory while loading only supported for safe tensors files")
-            state_dict = torch.load(file, weights_only=
+            state_dict = torch.load(file, weights_only=False, map_location="cpu")
             if "module" in state_dict:
                 state_dict = state_dict["module"]
-
         else:
             basename = os.path.basename(file)
 
             if "-of-" in basename:
-                metadata = None
                 file_parts= basename.split("-")
                 parts_max = int(file_parts[-1][:5])
                 state_dict = {}
@@ -1463,29 +1988,50 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
                     state_dict.update(sd)
             else:
                 state_dict, metadata = _safetensors_load_file(file, writable_tensors =writable_tensors)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if preprocess_sd != None:
+            state_dict = preprocess_sd(state_dict)
+
+        if metadata != None:
+            quantization_map = metadata.get("quantization_map", None)
+            config = metadata.get("config", None)
+            if config is not None:
+                model._config = config
+
+            tied_weights_map = metadata.get("tied_weights_map", None)
+            if tied_weights_map != None:
+                for name, tied_weights_list in tied_weights_map.items():
+                    mapped_weight = state_dict[name]
+                    for tied_weights in tied_weights_list:
+                        state_dict[tied_weights] = mapped_weight
+
+        if quantization_map is None and isinstance(file, str):
+            pos = str.rfind(file, ".")
+            if pos > 0:
+                quantization_map_path = file[:pos]
+                quantization_map_path += "_map.json"
+
+                if os.path.isfile(quantization_map_path):
+                    with open(quantization_map_path, 'r') as f:
+                        quantization_map = json.load(f)
+
+        if quantization_map is None:
+            conv_result = detect_and_convert(state_dict, default_dtype=default_dtype, verboseLevel=verboseLevel)
+            detected_kind = conv_result.get("kind")
+            if conv_result.get("kind") not in ("none", "quanto"):
+                state_dict = conv_result["state_dict"]
+                quantization_map = conv_result["quant_map"]
+            conv_result = None
+            # enable_fp8_fp32_scale_support()
+
+        if detected_kind in (None, "none") and isinstance(file, str) and (".safetensors" in file or ".sft" in file):
+            try:
+                info = detect_safetensors_format(state_dict, verboseLevel=verboseLevel)
+                detected_kind = info.get("kind")
+            except Exception:
+                detected_kind = detected_kind or None
+        if detected_kind not in (None, "none") and isinstance(file, str):
+            cache_quantization_for_file(file, detected_kind or "none")
 
         full_state_dict.update(state_dict)
         if quantization_map != None:
@@ -1504,8 +2050,8 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
|
|
|
1504
2050
|
full_state_dict, full_quantization_map, full_tied_weights_map = None, None, None
|
|
1505
2051
|
|
|
1506
2052
|
# deal if we are trying to load just a sub part of a larger model
|
|
1507
|
-
if
|
|
1508
|
-
state_dict, quantization_map =
|
|
2053
|
+
if postprocess_sd != None:
|
|
2054
|
+
state_dict, quantization_map = postprocess_sd(state_dict, quantization_map)
|
|
1509
2055
|
|
|
1510
2056
|
if modelPrefix != None:
|
|
1511
2057
|
base_model_prefix = modelPrefix + "."
|
|
@@ -1513,11 +2059,21 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
        if quantization_map != None:
            quantization_map = filter_state_dict(quantization_map,base_model_prefix)

+    post_load_hooks = []
+    if quantization_map:
+        quantization_map, post_load_hooks = apply_pre_quantization(
+            model,
+            state_dict,
+            quantization_map,
+            default_dtype=default_dtype,
+            verboseLevel=verboseLevel,
+        )
+
    if len(quantization_map) == 0:
-        if any("quanto" in file for file in file_path) and not do_quantize:
+        if any(isinstance(file, str) and "quanto" in file for file in file_path) and not do_quantize:
            print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
    else:
-        _requantize(model, state_dict, quantization_map)
+        _requantize(model, state_dict, quantization_map, default_dtype=default_dtype)



@@ -1530,13 +2086,25 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
                base_model_prefix = k[:-len(missing_keys[0])]
                break
        if base_model_prefix == None:
-
-
-
+            if not ignore_missing_keys:
+                raise Exception(f"Missing keys: {missing_keys}")
+        else:
+            state_dict = filter_state_dict(state_dict, base_model_prefix)
+            missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
+            if len(missing_keys) > 0 and not ignore_missing_keys:
+                raise Exception(f"Missing keys: {missing_keys}")

        del state_dict

-    if
+    if post_load_hooks:
+        for hook in post_load_hooks:
+            try:
+                hook(model)
+            except Exception as e:
+                if verboseLevel >= 2:
+                    print(f"Post-load hook skipped: {e}")
+
+    if len(unexpected_keys) > 0 and verboseLevel >=2 and not ignore_unused_weights:
        print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")

    for k,p in model.named_parameters():
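The retry path above only works if filter_state_dict keeps the keys under a given prefix and strips that prefix, so that a checkpoint of a larger model can feed one of its sub-models. filter_state_dict itself is not shown in this section of the diff; a minimal stand-in with that assumed behaviour:

# Assumed behaviour of filter_state_dict (helper not shown in this hunk): keep keys
# that start with `prefix` and return them with the prefix removed.
def filter_state_dict(state_dict: dict, prefix: str) -> dict:
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# {"model.diffusion_model.proj.weight": w} filtered with "model.diffusion_model."
# becomes {"proj.weight": w}, which the sub-model's load_state_dict can consume.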
@@ -1728,6 +2296,35 @@ class HfHook:
    def detach_hook(self, module):
        return module

+def _mm_lora_linear_forward(module, *args, **kwargs):
+    loras_data = getattr(module, "_mm_lora_data", None)
+    if not loras_data:
+        return module._mm_lora_old_forward(*args, **kwargs)
+    if not hasattr(module, "_mm_manager"):
+        pass
+    return module._mm_manager._lora_linear_forward(
+        module._mm_lora_model,
+        module,
+        loras_data,
+        *args,
+        **kwargs,
+    )
+
+
+def _mm_lora_generic_forward(module, *args, **kwargs):
+    loras_data = getattr(module, "_mm_lora_data", None)
+    if not loras_data:
+        return module._mm_lora_old_forward(*args, **kwargs)
+    return module._mm_manager._lora_generic_forward(
+        module._mm_lora_model,
+        module,
+        loras_data,
+        module._mm_lora_old_forward,
+        *args,
+        **kwargs,
+    )
+
+
last_offload_obj = None
class offload:
    def __init__(self):
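These two shims replace the previous per-module closures: every wrapped layer carries its own _mm_lora_data, _mm_lora_model and _mm_lora_old_forward attributes, and one module-level function dispatches on them. Keeping the state on the module instead of in a closure means a single shared code object serves every wrapped layer, which avoids building a new nested function per module and plays better with tools that inspect or compile the forward. A generic illustration of the pattern (names and the extra-bias example are illustrative, not part of mmgp):

import functools
import torch

def shim_forward(module, *args, **kwargs):
    # per-instance state lives on the module, not in a closure
    extra = getattr(module, "_extra_bias", None)
    out = module._old_forward(*args, **kwargs)
    return out if extra is None else out + extra

lin = torch.nn.Linear(4, 4)
lin._old_forward = lin.forward
lin._extra_bias = torch.zeros(4)
lin.forward = functools.update_wrapper(functools.partial(shim_forward, lin), lin._old_forward)
# lin.forward(torch.randn(2, 4)) now routes through the shared shim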
@@ -1757,6 +2354,7 @@ class offload:
        global last_offload_obj
        last_offload_obj = self

+        self._type_wrappers = {}

    def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name, submodule_name):

@@ -1781,22 +2379,12 @@ class offload:
            param_size = 0
            ref = _get_tensor_ref(p)
            tied_param = self.parameters_ref.get(ref, None)
-
-
-
-
-
-                    param_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
-                    param_size += torch.numel(p._data._data) * p._data._data.element_size()
-                else:
-                    param_size += torch.numel(p._scale) * p._scale.element_size()
-                    param_size += torch.numel(p._shift) * p._shift.element_size()
-                    param_size += torch.numel(p._data._data) * p._data._data.element_size()
-            else:
-                param_size += torch.numel(p._scale) * p._scale.element_size()
-                param_size += torch.numel(p._data) * p._data.element_size()
+                blocks_params.append((submodule, k, p, False, tied_param))
+                sub_tensors = _get_quantized_subtensors(p)
+                if sub_tensors:
+                    param_size += _subtensors_nbytes(sub_tensors)
+                del sub_tensors
            else:
-                blocks_params.append( (submodule, k, p, False, tied_param) )
                param_size += torch.numel(p.data) * p.data.element_size()


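The removed branches sized each quanto tensor layout by hand (_scale / _shift / _scale_shift plus the packed _data); the new code delegates to _get_quantized_subtensors and _subtensors_nbytes, which are defined elsewhere in offload.py and not visible in this section. From their use here and later in all(), they appear to return (name, tensor) pairs and to sum their byte sizes; a minimal sketch under that assumption:

import torch

def subtensors_nbytes_sketch(sub_tensors):
    # sum of numel * element_size over every quantized component (data, scale, shift, ...)
    return sum(t.numel() * t.element_size() for _, t in sub_tensors)

# e.g. an int8 weight [128, 256] plus a float16 per-row scale [128, 1]:
w = torch.zeros(128, 256, dtype=torch.int8)
s = torch.zeros(128, 1, dtype=torch.float16)
print(subtensors_nbytes_sketch([("data", w), ("scale", s)]))  # 32768 + 256 = 33024 bytes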
@@ -2091,7 +2679,7 @@ class offload:
            data = loras_data.get(active_adapter + '_GPU', None)
            if data == None:
                continue
-            diff_w , _ , diff_b, alpha = data
+            diff_w , _ , diff_b, _, alpha = data
            scaling = self._get_lora_scaling( loras_scaling, model, active_adapter) * alpha
            if scaling == 0:
                continue
@@ -2117,15 +2705,117 @@ class offload:
        return ret


+    def _dora_linear_forward(
+        self,
+        model,
+        submodule,
+        adapters_data,   # dict: name+"_GPU" -> (A, B, diff_b, g_abs, alpha); g_abs=None means LoRA
+        weight= None,
+        bias = None,
+        original_bias = True,
+        dora_mode: str = "blend",   # "ref_exact" | "blend"
+    ):
+        active_adapters = getattr(model, "_loras_active_adapters", [])
+        loras_scaling = getattr(model, "_loras_scaling", {})
+        # Snapshot base weight (safe for quantized modules)
+        if weight is None:
+            bias = submodule.bias
+            original_bias = True
+            if isinstance(submodule, QModuleMixin):
+                weight = submodule.weight.view(submodule.weight.shape)
+            else:
+                weight = submodule.weight.clone()
+
+        base_dtype = weight.dtype
+        eps = 1e-8
+        W0 = weight.float()
+        g0 = torch.linalg.vector_norm(W0, dim=1, keepdim=True, dtype=torch.float32).clamp_min(eps)   # [out,1]
+
+        # Keep big mats in low precision
+        # Wc = W0 if W0.dtype == compute_dtype else W0.to(compute_dtype)
+        W0 /= g0
+        weight[...] = W0.to(base_dtype)
+        W0 = None
+
+        dir_update = None   # Σ s * ((B@A)/g0) in compute_dtype
+        g = None            # final magnitude: set absolute (ref_exact) or blended (blend)
+        bias_delta = None   # Σ s * diff_b
+
+        # Accumulate DoRA adapters only (g_abs != None)
+        for name in active_adapters:
+            data = adapters_data.get(name + "_GPU", None)
+            if data is None: continue
+            A, B, diff_b, g_abs, alpha = data
+            if g_abs is None: continue
+
+            s = self._get_lora_scaling(loras_scaling, model, name) * float(alpha)
+            if s == 0: continue
+
+            # Direction update in V-space with row-wise 1/g0
+            if (A is not None) and (B is not None):
+                dV = torch.mm(B, A)   # [out,in], compute_dtype
+                dV /= g0              # row-wise divide
+                dV.mul_(s)
+                dir_update = dV if dir_update is None else dir_update.add_(dV)
+
+
+            if dora_mode == "ref_exact":
+                # absolute magnitude (last one wins if multiple DoRAs present)
+                g = g_abs
+            elif dora_mode == "blend":
+                # blend towards absolute magnitude proportional to s
+                if g is None:
+                    g = g0.clone()
+                g.add_(g_abs.sub(g0), alpha=s)
+            else:
+                raise ValueError(f"Unknown dora_mode: {dora_mode}")
+
+            # Optional bias deltas (not in reference, but harmless if present)
+            if diff_b is not None:
+                db = diff_b.mul(s)
+                bias_delta = db if bias_delta is None else bias_delta.add_(db)
+                db = None
+
+        if g is None:
+            g = g0   # no magnitude provided -> keep original
+
+        # Re-normalize rows if we changed direction
+        if dir_update is not None:
+            weight.add_(dir_update)
+            V = weight.float()
+            Vn = torch.linalg.vector_norm(V, dim=1, keepdim=True, dtype=torch.float32).clamp_min(eps)
+            V /= Vn
+            V *= g
+            weight[...] = V.to(base_dtype)
+            V = None
+        else:
+            weight *= g
+        # Recompose adapted weight; cast back to module dtype
+
+        # Merge DoRA bias delta safely
+        if bias_delta is not None:
+            if bias is None:
+                bias = bias_delta
+            else:
+                bias = bias.clone() if original_bias else bias
+                bias.add_(bias_delta)
+
+        return weight, bias
+
+
+
    def _lora_linear_forward(self, model, submodule, loras_data, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        weight = submodule.weight
+        bias = submodule.bias
        active_adapters = model._loras_active_adapters
        loras_scaling = model._loras_scaling
+        any_dora = loras_data.get("any_dora", False)
+        is_nvfp4 = getattr(submodule, "is_nvfp4", False)
        training = False

-        dtype = weight.dtype
-        if weight.shape[-1] < x.shape[-2]: # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
-
+        dtype = weight.dtype
+        if (weight.shape[-1] < x.shape[-2] and False or any_dora): # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
+            original_bias = True
            original_bias = True
        if len(active_adapters) > 0:
            if isinstance(submodule, QModuleMixin):
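_dora_linear_forward implements the DoRA recomposition: the base weight is split into a per-row magnitude g0 = ||W0||_row and a direction V0 = W0 / g0, the low-rank update B @ A is applied in direction space, the rows are re-normalised, and the result is rescaled by either the adapter's absolute magnitude ("ref_exact") or a magnitude blended towards it in proportion to the scaling ("blend"). A standalone numeric sketch of the same recomposition on dense tensors (single adapter, no quantization, no bias delta):

import torch

def dora_merge(W0, A, B, g_abs, s=1.0, eps=1e-8):
    # W0: [out, in] base weight, A: [r, in], B: [out, r], g_abs: [out, 1] learned magnitude
    g0 = torch.linalg.vector_norm(W0, dim=1, keepdim=True).clamp_min(eps)
    V = W0 / g0 + s * (B @ A) / g0                                   # update the direction in V-space
    V = V / torch.linalg.vector_norm(V, dim=1, keepdim=True).clamp_min(eps)
    g = g0 + s * (g_abs - g0)                                        # "blend" mode; "ref_exact" uses g_abs directly
    return g * V                                                     # re-attach the magnitude row-wise

W0 = torch.randn(8, 16)
A, B = torch.randn(2, 16), torch.randn(8, 2)
g_abs = torch.linalg.vector_norm(W0 + B @ A, dim=1, keepdim=True)
W_adapted = dora_merge(W0, A, B, g_abs)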
@@ -2136,10 +2826,17 @@ class offload:
                data = loras_data.get(active_adapter + '_GPU', None)
                if data == None:
                    continue
-                lora_A_weight, lora_B_weight, diff_b, alpha = data
+                lora_A_weight, lora_B_weight, diff_b, g_abs, alpha = data
                scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
-                if scaling == 0:
+                if scaling == 0 or g_abs is not None:
                    continue
+                target_dtype = weight.dtype
+                if lora_A_weight is not None and lora_A_weight.dtype != target_dtype:
+                    lora_A_weight = lora_A_weight.to(target_dtype)
+                if lora_B_weight is not None and lora_B_weight.dtype != target_dtype:
+                    lora_B_weight = lora_B_weight.to(target_dtype)
+                if diff_b is not None and diff_b.dtype != target_dtype:
+                    diff_b = diff_b.to(target_dtype)
                if lora_A_weight != None:
                    weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )

@@ -2152,41 +2849,60 @@ class offload:
                        original_bias = False
                        bias.add_(diff_b, alpha=scaling)
                # base_weight += scaling * lora_B_weight @ lora_A_weight
+
+            if any_dora :
+                weight, bias = self._dora_linear_forward(model, submodule, loras_data, weight, bias, original_bias)
            if training:
                pass
                # result = torch.nn.functional.linear(dropout(x), base_weight, bias=submodule.bias)
            else:
-
+                base_bias = bias
+                if base_bias is not None and base_bias.dtype != x.dtype:
+                    base_bias = base_bias.to(x.dtype)
+                result = torch.nn.functional.linear(x, weight, bias=base_bias)

        else:
-
+            base_bias = bias
+            if base_bias is not None and base_bias.dtype != x.dtype:
+                base_bias = base_bias.to(x.dtype)
+            result = torch.nn.functional.linear(x, weight, bias=base_bias)

            if len(active_adapters) > 0:
-
+                compute_dtype = torch.float32 if is_nvfp4 else result.dtype
+                if result.dtype != compute_dtype:
+                    result = result.to(compute_dtype)
+                    x = x.to(compute_dtype)

                for active_adapter in active_adapters:
                    data = loras_data.get(active_adapter + '_GPU', None)
                    if data == None:
                        continue
-                    lora_A, lora_B, diff_b, alpha = data
+                    lora_A, lora_B, diff_b, g_abs, alpha = data
                    # dropout = self.lora_dropout[active_adapter]
                    scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
-                    if scaling == 0:
+                    if scaling == 0 or g_abs is not None:
                        continue
+                    target_dtype = result.dtype
+                    if lora_A is not None and lora_A.dtype != target_dtype:
+                        lora_A = lora_A.to(target_dtype)
+                    if lora_B is not None and lora_B.dtype != target_dtype:
+                        lora_B = lora_B.to(target_dtype)
+                    if diff_b is not None and diff_b.dtype != target_dtype:
+                        diff_b = diff_b.to(target_dtype)
+
                    if lora_A == None:
                        result.add_(diff_b, alpha=scaling)
                    else:
-
-
-
-
-
-
-                        y = torch.nn.functional.linear(x, lora_A, bias=None)
-                        y = torch.nn.functional.linear(y, lora_B, bias=diff_b)
-                        y*= scaling
-                        result+= y
+                        x_2d = x.reshape(-1, x.shape[-1])
+                        result_2d = result.reshape(-1, result.shape[-1])
+                        y = x_2d @ lora_A.T
+                        result_2d.addmm_(y, lora_B.T, beta=1, alpha=scaling)
+                        if diff_b is not None:
+                            result_2d.add_(diff_b, alpha=scaling)
                    del y
+            target_dtype = input_dtype if is_nvfp4 else dtype
+            if result.dtype != target_dtype:
+                result = result.to(target_dtype)

        return result

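For the per-adapter path the old code built a full-size temporary activation with two F.linear calls, scaled it, then added it to the result; the new code flattens the activations to 2-D and folds the second matmul, the scaling and the accumulation into a single in-place addmm_, so only the small rank-sized intermediate is materialised. The two formulations are numerically equivalent, as this small check illustrates:

import torch

x = torch.randn(2, 5, 64)            # [batch, tokens, in_features]
result = torch.randn(2, 5, 32)       # base linear output, [batch, tokens, out_features]
A, B = torch.randn(8, 64), torch.randn(32, 8)   # rank-8 LoRA factors
scaling = 0.7

# old: two full linear passes plus an explicit scaled add
ref = result + scaling * torch.nn.functional.linear(torch.nn.functional.linear(x, A), B)

# new: in-place rank-k update on a 2-D view
out = result.clone()
x_2d, out_2d = x.reshape(-1, 64), out.reshape(-1, 32)
out_2d.addmm_(x_2d @ A.T, B.T, beta=1, alpha=scaling)

print(torch.allclose(ref, out, atol=1e-5))  # True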
@@ -2198,22 +2914,14 @@ class offload:
            assert submodule_name not in loras_model_shortcuts
            loras_model_shortcuts[submodule_name] = loras_data
            loras_model_data[submodule] = loras_data
+            submodule._mm_lora_data = loras_data
+            submodule._mm_lora_model = current_model
+            submodule._mm_lora_old_forward = old_forward

-            if isinstance(submodule, torch.nn.Linear):
-
-                if len(loras_data) == 0:
-                    return old_forward(*args, **kwargs)
-                else:
-                    submodule.aaa = submodule_name
-                    return self._lora_linear_forward(current_model, submodule, loras_data, *args, **kwargs)
-                target_fn = lora_linear_forward
+            if isinstance(submodule, torch.nn.Linear) or getattr(submodule, "is_nvfp4", False):
+                target_fn = _mm_lora_linear_forward
            else:
-
-                if len(loras_data) == 0:
-                    return old_forward(*args, **kwargs)
-                else:
-                    return self._lora_generic_forward(current_model, submodule, loras_data, old_forward, *args, **kwargs)
-                target_fn = lora_generic_forward
+                target_fn = _mm_lora_generic_forward
            return functools.update_wrapper(functools.partial(target_fn, submodule), old_forward)

    def ensure_model_loaded(self, model_id):
@@ -2236,10 +2944,65 @@ class offload:

        # need to be registered before the forward not to be break the efficiency of the compilation chain
        # it should be at the top of the compilation as this type of hook in the middle of a chain seems to break memory performance
-        target_module.register_forward_pre_hook(preload_blocks_for_compile)
+        target_module.register_forward_pre_hook(preload_blocks_for_compile)
+


-
+
+    @torch._dynamo.disable
+    def _pre_check(self, module):
+        model_id = getattr(module, "_mm_model_id", None)
+        blocks_name = getattr(module, "_mm_blocks_name", None)
+
+        self.ensure_model_loaded(model_id)
+        if blocks_name is None:
+            if self.ready_to_check_mem():
+                self.empty_cache_if_needed()
+        elif blocks_name != self.loaded_blocks[model_id] and \
+            blocks_name not in self.preloaded_blocks_per_model[model_id]:
+            self.gpu_load_blocks(model_id, blocks_name)
+
+    def _get_wrapper_for_type(self, mod_cls):
+        fn = self._type_wrappers.get(mod_cls)
+        if fn is not None:
+            return fn
+
+        # Unique function name per class -> unique compiled code object
+        fname = f"_mm_wrap_{mod_cls.__module__.replace('.', '_')}_{mod_cls.__name__}"
+
+        # Keep body minimal; all heavy/offload logic runs out-of-graph in _pre_check
+        # Include __TYPE_CONST in the code so the bytecode/consts differ per class.
+        src = f"""
+def {fname}(module, *args, **kwargs):
+    _ = __TYPE_CONST  # anchor type as a constant to make code object unique per class
+    nada = "{fname}"
+    mgr = module._mm_manager
+    mgr._pre_check(module)
+    return module._mm_forward(*args, **kwargs) #{fname}
+"""
+        ns = {"__TYPE_CONST": mod_cls}
+        exec(src, ns)  # compile a new function object/code object for this class
+        fn = ns[fname]
+        self._type_wrappers[mod_cls] = fn
+        return fn
+
+    def hook_check_load_into_GPU_if_needed(
+        self, target_module, model, model_id, blocks_name, previous_method, context
+    ):
+        # store instance data on the module (not captured by the wrapper)
+        target_module._mm_manager = self
+        target_module._mm_model_id = model_id
+        target_module._mm_blocks_name = blocks_name
+        target_module._mm_forward = previous_method
+
+        # per-TYPE wrapper (unique bytecode per class, reused across instances of that class)
+        wrapper_fn = self._get_wrapper_for_type(type(target_module))
+
+        # bind as a bound method (no partial/closures)
+        # target_module.forward = types.MethodType(wrapper_fn, target_module)
+        target_module.forward = functools.update_wrapper(functools.partial(wrapper_fn, target_module), previous_method)
+
+    def hook_check_load_into_GPU_if_needed_default(self, target_module, model, model_id, blocks_name, previous_method, context):

        dtype = model._dtype
        qint4quantization = isinstance(target_module, QModuleMixin) and target_module.weight!= None and target_module.weight.qtype == qint4
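_get_wrapper_for_type builds one wrapper function per module class by exec-ing a tiny source string whose name and constants embed that class, and caches it in self._type_wrappers. Per the comments in the diff, the goal is a unique compiled code object per layer type, shared by all instances of that type, with the offload bookkeeping kept out of the traced graph behind @torch._dynamo.disable. A stripped-down illustration of just the per-type cache mechanics (independent of mmgp):

_wrappers = {}

def wrapper_for_type(mod_cls):
    fn = _wrappers.get(mod_cls)
    if fn is None:
        fname = f"wrap_{mod_cls.__name__}"
        src = (f"def {fname}(module, *args, **kwargs):\n"
               f"    _ = __TYPE_CONST  # distinct constant -> distinct code object\n"
               f"    return module._inner(*args, **kwargs)\n")
        ns = {"__TYPE_CONST": mod_cls}
        exec(src, ns)
        fn = _wrappers[mod_cls] = ns[fname]
    return fn

class A: pass
class B: pass
print(wrapper_for_type(A) is wrapper_for_type(A))                    # True: cached per class
print(wrapper_for_type(A).__code__ is wrapper_for_type(B).__code__)  # False: one code object per class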
@@ -2259,22 +3022,33 @@ class offload:
            target_module.forward = target_module._mm_forward
            return

-        def
+        def check_load_into_GPU_needed():
            self.ensure_model_loaded(model_id)
            if blocks_name == None:
                if self.ready_to_check_mem():
                    self.empty_cache_if_needed()
            elif blocks_name != self.loaded_blocks[model_id] and blocks_name not in self.preloaded_blocks_per_model[model_id]:
                self.gpu_load_blocks(model_id, blocks_name)
-            if qint4quantization and dtype !=None:
-
-
-
+            # if qint4quantization and dtype !=None:
+            #     args, kwargs = self.move_args_to_gpu(dtype, *args, **kwargs)
+
+        if isinstance(target_module, torch.nn.Linear):
+            def check_load_into_GPU_needed_linear(module, *args, **kwargs):
+                check_load_into_GPU_needed()
+                return previous_method(*args, **kwargs) # linear
+            check_load_into_GPU_needed_module = check_load_into_GPU_needed_linear
+        else:
+            def check_load_into_GPU_needed_other(module, *args, **kwargs):
+                check_load_into_GPU_needed()
+                return previous_method(*args, **kwargs) # other
+            check_load_into_GPU_needed_module = check_load_into_GPU_needed_other

        setattr(target_module, "_mm_id", model_id)
+        setattr(target_module, "_mm_manager", self)
        setattr(target_module, "_mm_forward", previous_method)

-        setattr(target_module, "forward", functools.update_wrapper(functools.partial(
+        setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_load_into_GPU_needed_module, target_module), previous_method) )
+        # target_module.register_forward_pre_hook(check_empty_cuda_cache)


    def hook_change_module(self, target_module, model, model_id, module_id, previous_method, previous_method_name ):
@@ -2300,7 +3074,7 @@ class offload:
        if not self.verboseLevel >=1:
            return

-        if module_id == None or module_id =='':
+        if previous_method_name =="forward" and (module_id == None or module_id ==''):
            model_name = model._get_name()
            print(f"Hooked to model '{model_id}' ({model_name})")

@@ -2415,7 +3189,7 @@ class offload:



-def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, partialPinning = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, convertWeightsFloatTo = torch.bfloat16, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
+def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, partialPinning = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, convertWeightsFloatTo = torch.bfloat16, perc_reserved_mem_max = 0, coTenantsMap = None, vram_safety_coefficient = 0.8, compile_mode ="default", verboseLevel = -1):
    """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
    pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
    quantizeTransformer: set True by default will quantize on the fly the video / image model
@@ -2424,6 +3198,8 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
    budgets: 0 by default (unlimited). If non 0, it corresponds to the maximum size in MB that every model will occupy at any moment
     (in fact the real usage is twice this number). It is very efficient to reduce VRAM consumption but this feature may be very slow
     if pinnedMemory is not enabled
+    vram_safety_coefficient: float between 0 and 1 (exclusive), default 0.8. Sets the maximum portion of VRAM that can be used for models.
+     Lower values provide more safety margin but may reduce performance.
    """
    self = offload()
    self.verboseLevel = verboseLevel
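Two new knobs surface in the all() entry point: vram_safety_coefficient caps the share of VRAM a single model may plan to occupy (validated below to be a float strictly between 0 and 1), and compile_mode is forwarded to torch.compile for the detected towers. A hedged usage sketch following the library's usual offload.all(...) call; the stand-in pipeline and the chosen values are illustrative, and a CUDA device is assumed at runtime:

import torch
from mmgp import offload

# Any dict of nn.Modules is accepted for pipe_or_dict_of_modules; a real pipeline works the same way.
pipe = {"transformer": torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))}

offload.all(
    pipe,
    pinnedMemory=False,
    budgets=0,                       # unlimited, the default
    vram_safety_coefficient=0.8,     # never plan for more than 80% of VRAM per model
    compile=False,                   # set True to compile the detected towers...
    compile_mode="default",          # ...with this torch.compile mode
    verboseLevel=1,
)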
@@ -2439,7 +3215,11 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
            return float(b[:-1]) * self.device_mem_capacity
        else:
            return b * ONE_MB
-
+
+    # Validate vram_safety_coefficient
+    if not isinstance(vram_safety_coefficient, float) or vram_safety_coefficient <= 0 or vram_safety_coefficient >= 1:
+        raise ValueError("vram_safety_coefficient must be a float between 0 and 1 (exclusive)")
+
    budget = 0
    if not budgets is None:
        if isinstance(budgets , dict):
@@ -2523,26 +3303,22 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p

        current_model_size = 0
        model_dtype = getattr(current_model, "_model_dtype", None)
+        fp8_fallback_dtype = None
        # if model_dtype == None:
        #     model_dtype = getattr(current_model, "dtype", None)
        for _ , m in current_model.named_modules():
            ignore_dtype = hasattr(m, "_lock_dtype")
            for n, p in m.named_parameters(recurse = False):
                p.requires_grad = False
-
-
-
-
-
-
-
-
-
-                    else:
-                        current_model_size += torch.numel(p._scale) * p._scale.element_size()
-                        current_model_size += torch.numel(p._data) * p._data.element_size()
-                        dtype = p._scale.dtype
-
+                sub_tensors = _get_quantized_subtensors(p)
+                if sub_tensors:
+                    current_model_size += _subtensors_nbytes(sub_tensors)
+                    dtype = sub_tensors[0][1].dtype
+                    for name, tensor in sub_tensors:
+                        if name in ("scale", "scale_shift"):
+                            dtype = tensor.dtype
+                            break
+                    del sub_tensors
                else:
                    if not ignore_dtype:
                        dtype = p.data.dtype
@@ -2551,14 +3327,25 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
                        dtype = convertWeightsFloatTo if model_dtype == None else model_dtype
                        if dtype != torch.float32:
                            p.data = p.data.to(dtype)
-                    if model_dtype
-
+                    if model_dtype is None:
+                        if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+                            if fp8_fallback_dtype is None:
+                                fp8_fallback_dtype = dtype
+                        else:
+                            model_dtype = dtype
                    else:
                        if model_dtype != dtype:
-
-
+                            if (
+                                dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+                                or model_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+                            ):
+                                pass
+                            else:
+                                assert model_dtype == dtype
                    current_model_size += torch.numel(p.data) * p.data.element_size()
-
+        if model_dtype is None and fp8_fallback_dtype is not None:
+            model_dtype = fp8_fallback_dtype
+        current_model._dtype = model_dtype
        for b in current_model.buffers():
            # do not convert 32 bits float to 16 bits since buffers are few (and potential gain low) and usually they are needed for precision calculation (for instance Rope)
            current_model_size += torch.numel(b.data) * b.data.element_size()
@@ -2584,14 +3371,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
            model_budget = new_budget if model_budget == 0 or new_budget < model_budget else model_budget
        if model_budget > 0 and model_budget > current_model_size:
            model_budget = 0
-        coef =
+        coef =vram_safety_coefficient
        if current_model_size > coef * self.device_mem_capacity and model_budget == 0 or model_budget > coef * self.device_mem_capacity:
            if verboseLevel >= 1:
                if model_budget == 0:
-                    print(f"Model '{model_id}' is too large ({current_model_size/ONE_MB:0.1f} MB) to fit entirely in {coef * 100}% of the VRAM (max capacity is {coef * self.device_mem_capacity/ONE_MB}) MB)")
+                    print(f"Model '{model_id}' is too large ({current_model_size/ONE_MB:0.1f} MB) to fit entirely in {coef * 100:.0f}% of the VRAM (max capacity is {coef * self.device_mem_capacity/ONE_MB:0.1f}) MB)")
                else:
                    print(f"Budget ({budget/ONE_MB:0.1f} MB) for Model '{model_id}' is too important so that this model can fit in the VRAM (max capacity is {self.device_mem_capacity/ONE_MB}) MB)")
-                print(f"Budget allocation for this model has been consequently reduced to the
+                print(f"Budget allocation for this model has been consequently reduced to the {coef * 100:.0f}% of max GPU Memory ({coef * self.device_mem_capacity/ONE_MB:0.1f} MB). This may not leave enough working VRAM and you will probably need to define manually a lower budget for this model.")
                model_budget = coef * self.device_mem_capacity


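The hard-coded coefficient in the old budget clamp is now the user-facing vram_safety_coefficient. A quick worked example of the resulting ceiling, assuming ONE_MB is 2**20 bytes as the constant's name suggests: on a 24 GB card with the default 0.8, any model or budget larger than roughly 19.2 GB gets clamped to that value.

ONE_MB = 1024 * 1024
device_mem_capacity = 24 * 1024 * ONE_MB      # a 24 GB card, for illustration
vram_safety_coefficient = 0.8

model_budget_cap = vram_safety_coefficient * device_mem_capacity
print(f"{model_budget_cap / ONE_MB:0.1f} MB")  # 19660.8 MB, i.e. ~19.2 GB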
@@ -2607,19 +3394,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
    for model_id in models:
        current_model: torch.nn.Module = models[model_id]
        towers_names, towers_modules = _detect_main_towers(current_model)
-        # compile main iterative modules stacks ("towers")
        compilationInThisOne = compileAllModels or model_id in modelsToCompile
-        if compilationInThisOne:
-            if self.verboseLevel>=1:
-                if len(towers_modules)>0:
-                    formated_tower_names = [name + '*' for name in towers_names]
-                    print(f"Pytorch compilation of '{model_id}' is scheduled for these modules : {formated_tower_names}.")
-                else:
-                    print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
-
-            for submodel in towers_modules:
-                submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
-                #dynamic=True,

        if pinAllModels or model_id in modelsToPin:
            if hasattr(current_model,"_already_pinned"):
@@ -2627,6 +3402,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
                print(f"Model '{model_id}' already pinned to reserved memory")
            else:
                _pin_to_memory(current_model, model_id, partialPinning= partialPinning, pinnedPEFTLora = pinnedPEFTLora, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel=verboseLevel)
+
        current_budget = model_budgets[model_id]
        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq, is_mod_seq = None, None, None, -1, False
        self.loaded_blocks[model_id] = None
@@ -2665,8 +3441,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
                # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
                top_submodule = len(submodule_name.split("."))==1
                offload_hooks = submodule._offload_hooks if hasattr(submodule, "_offload_hooks") else []
-                if len(offload_hooks) > 0:
-                    pass
                assert top_submodule or len(offload_hooks) == 0, "custom offload hooks can only be set at the of the module"
                submodule_method_names = ["forward"] + offload_hooks
                for submodule_method_name in submodule_method_names:
@@ -2676,16 +3450,32 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
                    else:
                        submodule_method = getattr(submodule, submodule_method_name)
                    if callable(submodule_method):
-                        if top_submodule and cur_blocks_name is None:
+                        if top_submodule and cur_blocks_name is None and not (any_lora and len(submodule._parameters)):
                            self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method, submodule_method_name)
                        elif compilationInThisOne and submodule in towers_modules:
                            self.hook_preload_blocks_for_compilation(submodule, model_id, cur_blocks_name, context = submodule_name )
                        else:
-
-
+                            if compilationInThisOne: #and False
+                                self.hook_check_load_into_GPU_if_needed(submodule, current_model, model_id, cur_blocks_name, submodule_method, context = submodule_name )
+                            else:
+                                self.hook_check_load_into_GPU_if_needed_default(submodule, current_model, model_id, cur_blocks_name, submodule_method, context = submodule_name )
+
                self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name, submodule_name)


+        # compile main iterative modules stacks ("towers")
+        if compilationInThisOne:
+            if self.verboseLevel>=1:
+                if len(towers_modules)>0:
+                    formated_tower_names = [name + '*' for name in towers_names]
+                    print(f"Pytorch compilation of '{model_id}' is scheduled for these modules : {formated_tower_names}.")
+                else:
+                    print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
+
+            for submodel in towers_modules:
+                submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode= compile_mode) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
+                #dynamic=True,
+
        self.tune_preloading(model_id, current_budget, towers_names)
        self.parameters_ref = {}

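Note the ordering change in this last hunk: torch.compile of the tower submodules now runs after the offload hooks have been installed (compare the block removed in the @@ -2607 hunk with the one re-added here), and it honours the new compile_mode argument. Presumably this keeps the preload and load-into-GPU checks at the top of the compiled region, which is what the earlier comment about registering the pre-hook "at the top of the compilation" asks for. A tiny, hedged illustration of that wrap-then-compile order on a standalone module:

import functools
import torch

def with_hook(module, hook):
    inner = module.forward
    def hooked(*args, **kwargs):
        hook()                       # offload bookkeeping would run here
        return inner(*args, **kwargs)
    module.forward = functools.update_wrapper(hooked, inner)
    return module

m = torch.nn.Linear(8, 8)
with_hook(m, lambda: None)            # 1) install the hook first
m.forward = torch.compile(m.forward)  # 2) then compile, so the hook sits at the top of the compiled region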