mmgp 3.5.7__py3-none-any.whl → 3.6.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
1
- # ------------------ Memory Management 3.5.7 for the GPU Poor by DeepBeepMeep (mmgp)------------------
1
+ # ------------------ Memory Management 3.6.11 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
2
  #
3
3
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
4
4
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -60,16 +60,23 @@ import functools
60
60
  import sys
61
61
  import os
62
62
  import json
63
+ import inspect
63
64
  import psutil
64
65
  import builtins
65
66
  from accelerate import init_empty_weights
66
-
67
+ from functools import wraps
68
+ import functools
69
+ import types
67
70
 
68
71
  from mmgp import safetensors2
69
72
  from mmgp import profile_type
70
-
73
+ from .quant_router import (
74
+ apply_pre_quantization,
75
+ cache_quantization_for_file,
76
+ detect_and_convert,
77
+ detect_safetensors_format,
78
+ )
71
79
  from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QLinear, QTensor, quantize_module, register_qmodule
72
-
73
80
  # support for Embedding module quantization that is not supported by default by quanto
74
81
  @register_qmodule(torch.nn.Embedding)
75
82
  class QEmbedding(QModuleMixin, torch.nn.Embedding):
@@ -83,8 +90,36 @@ class QEmbedding(QModuleMixin, torch.nn.Embedding):
83
90
  return torch.nn.functional.embedding( input, self.qweight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse )
84
91
 
85
92
 
93
+
94
+ def cudacontext(device):
95
+ def decorator(func):
96
+ @wraps(func)
97
+ def wrapper(*args, **kwargs):
98
+ with torch.device(device):
99
+ return func(*args, **kwargs)
100
+ return wrapper
101
+ return decorator
102
+
103
+
86
104
  shared_state = {}
87
105
 
106
+ def get_cache(cache_name):
107
+ all_cache = shared_state.get("_cache", None)
108
+ if all_cache is None:
109
+ all_cache = {}
110
+ shared_state["_cache"]= all_cache
111
+ cache = all_cache.get(cache_name, None)
112
+ if cache is None:
113
+ cache = {}
114
+ all_cache[cache_name] = cache
115
+ return cache
116
+
117
+ def clear_caches():
118
+ all_cache = shared_state.get("_cache", None)
119
+ if all_cache is not None:
120
+ all_cache.clear()
121
+
122
+
88
123
  mmm = safetensors2.mmm
89
124
 
90
125
  default_verboseLevel = 1
@@ -277,23 +312,112 @@ def _safetensors_load_file(file_path, writable_tensors = True):
277
312
 
278
313
  def _force_load_buffer(p):
279
314
  # To do : check if buffer was persistent and transfer state, or maybe swap keep already this property ?
280
- q = torch.nn.Buffer(p + 0)
315
+ q = torch.nn.Buffer(p.clone())
281
316
  torch.utils.swap_tensors(p, q)
282
317
  del q
283
318
 
284
319
  def _force_load_parameter(p):
285
- q = torch.nn.Parameter(p + 0)
320
+ q = torch.nn.Parameter(p.clone())
286
321
  torch.utils.swap_tensors(p, q)
287
322
  del q
288
323
 
289
- def _get_tensor_ref(p):
290
- if isinstance(p, QTensor):
291
- if p._qtype == qint4:
292
- return p._data._data.data_ptr()
324
+ def _unwrap_quantized_tensor(tensor):
325
+ if hasattr(tensor, "_data") and torch.is_tensor(tensor._data):
326
+ return tensor._data
327
+ return tensor
328
+
329
+ def _qtensor_get_quantized_subtensors(self):
330
+ subtensors = []
331
+ if getattr(self, "_qtype", None) == qint4:
332
+ data = _unwrap_quantized_tensor(self._data)
333
+ subtensors.append(("data", data))
334
+ if hasattr(self, "_scale_shift") and self._scale_shift is not None:
335
+ subtensors.append(("scale_shift", self._scale_shift))
336
+ else:
337
+ if hasattr(self, "_scale") and self._scale is not None:
338
+ subtensors.append(("scale", self._scale))
339
+ if hasattr(self, "_shift") and self._shift is not None:
340
+ subtensors.append(("shift", self._shift))
341
+ return subtensors
342
+
343
+ if hasattr(self, "_data"):
344
+ data = _unwrap_quantized_tensor(self._data)
345
+ subtensors.append(("data", data))
346
+ if hasattr(self, "_scale") and self._scale is not None:
347
+ subtensors.append(("scale", self._scale))
348
+ return subtensors
349
+
350
+ def _qtensor_set_quantized_subtensors(self, sub_tensors):
351
+ if isinstance(sub_tensors, dict):
352
+ sub_map = sub_tensors
353
+ else:
354
+ sub_map = {name: tensor for name, tensor in sub_tensors}
355
+
356
+ data = sub_map.get("data", None)
357
+ if data is not None:
358
+ if hasattr(self, "_data") and hasattr(self._data, "_data") and torch.is_tensor(self._data._data):
359
+ self._data._data = data
360
+ else:
361
+ self._data = data
362
+
363
+ if getattr(self, "_qtype", None) == qint4:
364
+ if "scale_shift" in sub_map and sub_map["scale_shift"] is not None:
365
+ self._scale_shift = sub_map["scale_shift"]
293
366
  else:
294
- return p._data.data_ptr()
295
- else:
296
- return p.data_ptr()
367
+ if "scale" in sub_map and sub_map["scale"] is not None:
368
+ self._scale = sub_map["scale"]
369
+ if "shift" in sub_map and sub_map["shift"] is not None:
370
+ self._shift = sub_map["shift"]
371
+ else:
372
+ if "scale" in sub_map and sub_map["scale"] is not None:
373
+ self._scale = sub_map["scale"]
374
+
375
+ if not hasattr(QTensor, "get_quantized_subtensors"):
376
+ QTensor.get_quantized_subtensors = _qtensor_get_quantized_subtensors
377
+ if not hasattr(QTensor, "set_quantized_subtensors"):
378
+ QTensor.set_quantized_subtensors = _qtensor_set_quantized_subtensors
379
+
380
+ def _get_quantized_subtensors(p):
381
+ getter = getattr(p, "get_quantized_subtensors", None)
382
+ if getter is None:
383
+ return None
384
+ sub_tensors = getter()
385
+ if not sub_tensors:
386
+ return None
387
+ if isinstance(sub_tensors, dict):
388
+ sub_tensors = list(sub_tensors.items())
389
+ out = []
390
+ for name, tensor in sub_tensors:
391
+ if tensor is None:
392
+ continue
393
+ if torch.is_tensor(tensor):
394
+ out.append((name, tensor))
395
+ return out if out else None
396
+
397
+ def _set_quantized_subtensors(p, sub_tensors):
398
+ setter = getattr(p, "set_quantized_subtensors", None)
399
+ if setter is None:
400
+ return False
401
+ setter(sub_tensors)
402
+ return True
403
+
404
+ def _subtensors_nbytes(sub_tensors):
405
+ return sum(torch.numel(t) * t.element_size() for _, t in sub_tensors)
406
+
407
+ def _subtensors_itemsize(sub_tensors, fallback):
408
+ for _, t in sub_tensors:
409
+ return t.element_size()
410
+ return fallback
411
+
412
+ def _get_tensor_ref(p):
413
+ sub_tensors = _get_quantized_subtensors(p)
414
+ if sub_tensors:
415
+ for _, t in sub_tensors:
416
+ ref = t.data_ptr()
417
+ del sub_tensors
418
+ return ref
419
+ del sub_tensors
420
+ return p.data_ptr()
297
421
 
298
422
 
299
423
  BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
@@ -516,25 +640,18 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
516
640
  tied_weights_last = f"{match_name} <-> {n}"
517
641
  tied_weights[n] = match_name
518
642
  else:
519
- if isinstance(p, QTensor):
520
- if p._qtype == qint4:
521
- if p._data._data.is_pinned():
522
- params_dict[n] = (None, False)
523
- continue
524
- if hasattr(p,"_scale_shift"):
525
- length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale_shift) * p._scale_shift.element_size()
526
- else:
527
- length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale) * p._scale.element_size() + torch.numel(p._shift) * p._shift.element_size()
528
- else:
529
- length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
530
- if p._data.is_pinned():
531
- params_dict[n] = (None, False)
532
- continue
643
+ sub_tensors = _get_quantized_subtensors(p)
644
+ if sub_tensors:
645
+ if builtins.all(t.is_pinned() for _, t in sub_tensors):
646
+ params_dict[n] = (None, False)
647
+ del sub_tensors
648
+ continue
649
+ length = _subtensors_nbytes(sub_tensors)
533
650
  else:
534
651
  if p.data.is_pinned():
535
652
  params_dict[n] = (None, False)
536
653
  continue
537
- length = torch.numel(p.data) * p.data.element_size()
654
+ length = torch.numel(p.data) * p.data.element_size()
538
655
 
539
656
  ref_cache[ref] = (n, length)
540
657
  if current_big_tensor_size + length > big_tensor_size and current_big_tensor_size !=0 :
@@ -542,8 +659,11 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
542
659
  current_big_tensor_size = 0
543
660
  big_tensor_no += 1
544
661
 
545
-
546
- itemsize = p.data.dtype.itemsize
662
+ if sub_tensors:
663
+ itemsize = _subtensors_itemsize(sub_tensors, p.data.dtype.itemsize)
664
+ del sub_tensors
665
+ else:
666
+ itemsize = p.data.dtype.itemsize
547
667
  if current_big_tensor_size % itemsize:
548
668
  current_big_tensor_size += itemsize - current_big_tensor_size % itemsize
549
669
  tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
@@ -580,15 +700,11 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
580
700
  q_name = tied_weights.get(n,None)
581
701
  if q_name != None:
582
702
  q , _ = params_dict[q_name]
583
- if isinstance(p, QTensor):
584
- if p._qtype == qint4:
585
- p._data._data = q._data._data
586
- p._scale_shift = q._scale_shift
587
- assert p._data._data.data.is_pinned()
588
- else:
589
- p._data = q._data
590
- p._scale = q._scale
591
- assert p._data.is_pinned()
703
+ sub_tensors = _get_quantized_subtensors(q)
704
+ if sub_tensors:
705
+ sub_map = {name: tensor for name, tensor in sub_tensors}
706
+ _set_quantized_subtensors(p, sub_map)
707
+ del sub_map, sub_tensors
592
708
  else:
593
709
  p.data = q.data
594
710
  assert p.data.is_pinned()
@@ -618,27 +734,21 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
618
734
  total += size
619
735
 
620
736
  current_big_tensor = big_tensors[big_tensor_no]
737
+
621
738
  if is_buffer :
622
739
  _force_load_buffer(p) # otherwise potential memory leak
623
- if isinstance(p, QTensor):
624
- if p._qtype == qint4:
625
- length1 = torch.numel(p._data._data) * p._data._data.element_size()
626
- p._data._data = _move_to_pinned_tensor(p._data._data, current_big_tensor, offset, length1)
627
- if hasattr(p,"_scale_shift"):
628
- length2 = torch.numel(p._scale_shift) * p._scale_shift.element_size()
629
- p._scale_shift = _move_to_pinned_tensor(p._scale_shift, current_big_tensor, offset + length1, length2)
630
- else:
631
- length2 = torch.numel(p._scale) * p._scale.element_size()
632
- p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
633
- length3 = torch.numel(p._shift) * p._shift.element_size()
634
- p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
635
- else:
636
- length1 = torch.numel(p._data) * p._data.element_size()
637
- p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
638
- length2 = torch.numel(p._scale) * p._scale.element_size()
639
- p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
740
+ sub_tensors = _get_quantized_subtensors(p)
741
+ if sub_tensors:
742
+ sub_offset = offset
743
+ new_subs = {}
744
+ for name, tensor in sub_tensors:
745
+ length = torch.numel(tensor) * tensor.element_size()
746
+ new_subs[name] = _move_to_pinned_tensor(tensor, current_big_tensor, sub_offset, length)
747
+ sub_offset += length
748
+ _set_quantized_subtensors(p, new_subs)
749
+ del new_subs, sub_tensors
640
750
  else:
641
- length = torch.numel(p.data) * p.data.element_size()
751
+ length = torch.numel(p.data) * p.data.element_size()
642
752
  p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
643
753
 
644
754
  tensor_no += 1
@@ -666,18 +776,22 @@ def _welcome():
666
776
  if welcome_displayed:
667
777
  return
668
778
  welcome_displayed = True
669
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.5.7) by DeepBeepMeep ************{ENDC}{UNBOLD}")
779
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.6.11) by DeepBeepMeep ************{ENDC}{UNBOLD}")
670
780
 
671
781
  def change_dtype(model, new_dtype, exclude_buffers = False):
672
782
  for submodule_name, submodule in model.named_modules():
673
783
  if hasattr(submodule, "_lock_dtype"):
674
784
  continue
675
785
  for n, p in submodule.named_parameters(recurse = False):
786
+ if isinstance(p, QTensor):
787
+ continue
676
788
  if p.data.dtype != new_dtype:
677
789
  p.data = p.data.to(new_dtype)
678
790
 
679
791
  if not exclude_buffers:
680
792
  for p in submodule.buffers(recurse=False):
793
+ if isinstance(p, QTensor):
794
+ continue
681
795
  if p.data.dtype != new_dtype:
682
796
  p.data = p.data.to(new_dtype)
683
797
 
@@ -751,7 +865,7 @@ def _quantize_submodule(
751
865
  setattr(module, name, None)
752
866
  del param
753
867
 
754
- def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict):
868
+ def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict, default_dtype=None):
755
869
  # change dtype of current meta model parameters because 'requantize' won't update the dtype on non quantized parameters
756
870
  for k, p in model.named_parameters():
757
871
  if not k in quantization_map and k in state_dict:
@@ -770,6 +884,11 @@ def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict
770
884
  if activations == "none":
771
885
  activations = None
772
886
  _quantize_submodule(model, name, m, weights=weights, activations=activations)
887
+ if default_dtype is not None:
888
+ new_module = model.get_submodule(name)
889
+ setter = getattr(new_module, "set_default_dtype", None)
890
+ if callable(setter):
891
+ setter(default_dtype)
773
892
 
774
893
  model._quanto_map = quantization_map
775
894
 
@@ -803,6 +922,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
803
922
 
804
923
  cache_ref = {}
805
924
  tied_weights= {}
925
+ reversed_tied_weights= {}
806
926
 
807
927
  for submodule_name, submodule in model_to_quantize.named_modules():
808
928
  if isinstance(submodule, QModuleMixin):
@@ -815,7 +935,9 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
815
935
  ref = _get_tensor_ref(p)
816
936
  match = cache_ref.get(ref, None)
817
937
  if match != None:
818
- tied_weights[submodule_name]= (n, ) + match
938
+ tied_weights[submodule_name]= (n, ) + match
939
+ entries = reversed_tied_weights.get( match, [])
940
+ reversed_tied_weights[match] = entries + [ (p, submodule_name,n)]
819
941
  else:
820
942
  cache_ref[ref] = (submodule_name, n)
821
943
  size += torch.numel(p.data) * sizeofhalffloat
@@ -883,6 +1005,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
883
1005
  # force to read non quantized parameters so that their lazy tensors and corresponding mmap are released
884
1006
  # otherwise we may end up keeping in memory both the quantized and the non quantize model
885
1007
  named_modules = {n:m for n,m in model_to_quantize.named_modules()}
1008
+
886
1009
  for module_name, module in named_modules.items():
887
1010
  # do not read quantized weights (detected them directly or behind an adapter)
888
1011
  if isinstance(module, QModuleMixin) or hasattr(module, "base_layer") and isinstance(module.base_layer, QModuleMixin):
@@ -891,12 +1014,18 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
891
1014
  else:
892
1015
  tied_w = tied_weights.get(module_name, None)
893
1016
  for n, p in module.named_parameters(recurse = False):
1017
+
894
1018
  if tied_w != None and n == tied_w[0]:
895
1019
  if isinstance( named_modules[tied_w[1]], QModuleMixin) :
896
1020
  setattr(module, n, None) # release refs of tied weights if source is going to be quantized
897
1021
  # otherwise don't force load as it will be loaded in the source anyway
898
1022
  else:
899
1023
  _force_load_parameter(p)
1024
+ entries = reversed_tied_weights.get( (module_name, n), [])
1025
+ for tied_weight, tied_module_name, tied_weight_name in entries:
1026
+ if n == tied_weight_name:
1027
+ tied_weight.data = p.data
1028
+
900
1029
  del p # del p if not it will still contain a ref to a tensor when leaving the loop
901
1030
  for b in module.buffers(recurse = False):
902
1031
  _force_load_buffer(b)
@@ -927,38 +1056,340 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
927
1056
 
928
1057
  return True
929
1058
 
930
- def split_linear_modules(model, map ):
931
- from optimum.quanto import QModuleMixin, WeightQBytesTensor, QLinear
1059
+ def _as_field_tuple(value):
1060
+ if not value:
1061
+ return ()
1062
+ if isinstance(value, str):
1063
+ return (value,)
1064
+ return tuple(value)
1065
+
1066
+
1067
+ def _get_split_handler(info, field, default_handlers):
1068
+ handlers = info.get("split_handlers") or info.get("field_handlers") or {}
1069
+ if handlers:
1070
+ handler = handlers.get(field)
1071
+ if handler is not None:
1072
+ return handler
1073
+ if default_handlers:
1074
+ return default_handlers.get(field)
1075
+ return None
1076
+
1077
+
1078
+ def _get_split_base_fields(info, split_fields):
1079
+ base_fields = _as_field_tuple(info.get("base_fields") or info.get("base_field"))
1080
+ if base_fields:
1081
+ return base_fields
1082
+ if split_fields:
1083
+ return (next(iter(split_fields.keys())),)
1084
+ return ()
1085
+
1086
+
1087
+ def _merge_share_fields(info, share_fields):
1088
+ info_fields = _as_field_tuple(info.get("share_fields") or info.get("shared_fields"))
1089
+ return tuple(sorted(set(info_fields).union(_as_field_tuple(share_fields))))
1090
+
1091
+
1092
+ def _call_split_handler(handler, *, src, dim, split_sizes, context):
1093
+ if handler is None:
1094
+ return None
1095
+ try:
1096
+ chunks = handler(src=src, dim=dim, split_sizes=split_sizes, context=context)
1097
+ except Exception:
1098
+ return None
1099
+ if not isinstance(chunks, (list, tuple)) or len(chunks) != len(split_sizes):
1100
+ return None
1101
+ return chunks
1102
+
1103
+
1104
+ def _fill_sub_maps(sub_maps, name, value):
1105
+ for sub_map in sub_maps:
1106
+ sub_map[name] = value
1107
+
1108
+
1109
+ def sd_split_linear(
1110
+ state_dict,
1111
+ split_map,
1112
+ split_fields=None,
1113
+ share_fields=None,
1114
+ verboseLevel=1,
1115
+ split_handlers=None,
1116
+ ):
1117
+ if not split_map:
1118
+ return state_dict
1119
+ split_fields = split_fields or {}
1120
+ share_fields = share_fields or ()
1121
+ split_handlers = split_handlers or {}
1122
+ base_fields_by_suffix = {
1123
+ suffix: _get_split_base_fields(info or {}, split_fields)
1124
+ for suffix, info in split_map.items()
1125
+ }
1126
+ def _skip(msg):
1127
+ if verboseLevel >= 2:
1128
+ print(f"[sd_split_linear] Skip {msg}")
1129
+
1130
+ bases = {}
1131
+ for key in state_dict.keys():
1132
+ for suffix, base_fields in base_fields_by_suffix.items():
1133
+ for base_field in base_fields:
1134
+ suffix_token = f"{suffix}.{base_field}"
1135
+ if not key.endswith(suffix_token):
1136
+ continue
1137
+ base = key[: -len("." + base_field)]
1138
+ if base.endswith(suffix):
1139
+ bases[base] = suffix
1140
+ break
1141
+
1142
+ if not bases:
1143
+ return state_dict
1144
+
1145
+ for base, suffix in bases.items():
1146
+ info = split_map.get(suffix) or {}
1147
+ mapped = info.get("mapped_modules") or info.get("mapped_suffixes") or info.get("mapped") or []
1148
+ if not mapped:
1149
+ continue
1150
+
1151
+ base_fields = base_fields_by_suffix.get(suffix) or _get_split_base_fields(info, split_fields)
1152
+ size_field = info.get("size_field") or (base_fields[0] if base_fields else None)
1153
+ size_tensor = state_dict.get(base + "." + size_field) if size_field else None
1154
+ split_dim = info.get("split_dim", 0)
1155
+ split_sizes = list(info.get("split_sizes") or [])
1156
+ if not split_sizes:
1157
+ if size_tensor is None:
1158
+ continue
1159
+ if size_tensor.dim() <= split_dim:
1160
+ _skip(f"{base}: dim={size_tensor.dim()} split_dim={split_dim}")
1161
+ continue
1162
+ out_dim = size_tensor.size(split_dim)
1163
+ if out_dim % len(mapped) != 0:
1164
+ _skip(f"{base}: out_dim={out_dim} not divisible by {len(mapped)}")
1165
+ continue
1166
+ split_sizes = [out_dim // len(mapped)] * len(mapped)
1167
+ elif None in split_sizes:
1168
+ if size_tensor is None:
1169
+ continue
1170
+ if size_tensor.dim() <= split_dim:
1171
+ _skip(f"{base}: dim={size_tensor.dim()} split_dim={split_dim}")
1172
+ continue
1173
+ known = sum(size for size in split_sizes if size is not None)
1174
+ none_count = split_sizes.count(None)
1175
+ remaining = size_tensor.size(split_dim) - known
1176
+ if remaining < 0 or remaining % none_count != 0:
1177
+ _skip(f"{base}: cannot resolve split sizes")
1178
+ continue
1179
+ fill = remaining // none_count
1180
+ split_sizes = [fill if size is None else size for size in split_sizes]
1181
+
1182
+ total = sum(split_sizes)
1183
+ prefix = base[: -len(suffix)]
1184
+ target_bases = [prefix + name for name in mapped]
1185
+ added = 0
1186
+
1187
+ field_tensors = {
1188
+ field: state_dict.get(base + "." + field)
1189
+ for field in set(split_fields.keys()).union(share_fields)
1190
+ }
1191
+ base_ctx = {
1192
+ "state_dict": state_dict,
1193
+ "base": base,
1194
+ "suffix": suffix,
1195
+ "split_sizes": split_sizes,
1196
+ "total": total,
1197
+ "mapped": mapped,
1198
+ "target_bases": target_bases,
1199
+ "verboseLevel": verboseLevel,
1200
+ "split_fields": split_fields,
1201
+ "share_fields": share_fields,
1202
+ "field_tensors": field_tensors,
1203
+ "size_field": size_field,
1204
+ "size_tensor": size_tensor,
1205
+ "split_dim": split_dim,
1206
+ "info": info,
1207
+ }
1208
+ fields_iter = list(split_fields.items()) + [(field, None) for field in share_fields]
1209
+ for field, dim in fields_iter:
1210
+ src = field_tensors.get(field)
1211
+ if src is None:
1212
+ continue
1213
+ if dim is None:
1214
+ for target_base in target_bases:
1215
+ dest_key = target_base + "." + field
1216
+ if dest_key not in state_dict:
1217
+ state_dict[dest_key] = src
1218
+ added += 1
1219
+ continue
1220
+ if src.dim() <= dim:
1221
+ _skip(f"{base}.{field}: dim={src.dim()} split_dim={dim}")
1222
+ continue
1223
+ if src.size(dim) != total:
1224
+ _skip(f"{base}.{field}: size({dim})={src.size(dim)} expected={total}")
1225
+ continue
1226
+ handler = _get_split_handler(info, field, split_handlers)
1227
+ chunks = _call_split_handler(
1228
+ handler,
1229
+ src=src,
1230
+ dim=dim,
1231
+ split_sizes=split_sizes,
1232
+ context=dict(base_ctx, field=field),
1233
+ )
1234
+ if chunks is None:
1235
+ chunks = torch.split(src, split_sizes, dim=dim)
1236
+ for target_base, chunk in zip(target_bases, chunks):
1237
+ if torch.is_tensor(chunk) and not chunk.is_contiguous():
1238
+ chunk = chunk.contiguous()
1239
+ dest_key = target_base + "." + field
1240
+ if dest_key not in state_dict:
1241
+ state_dict[dest_key] = chunk
1242
+ added += 1
1243
+
1244
+ if added:
1245
+ for field in list(split_fields.keys()) + list(share_fields):
1246
+ state_dict.pop(base + "." + field, None)
1247
+ if verboseLevel >= 2:
1248
+ print(f"[sd_split_linear] Split {base} -> {', '.join(mapped)}")
1249
+
1250
+ return state_dict
1251
+
1252
+
1253
+ def split_linear_modules(model, map, split_handlers=None, share_fields=None):
1254
+ from optimum.quanto import QModuleMixin
932
1255
  from accelerate import init_empty_weights
933
1256
 
1257
+ split_handlers = split_handlers or {}
1258
+ share_fields = share_fields or ()
1259
+
934
1260
  modules_dict = { k: m for k, m in model.named_modules()}
935
1261
  for module_suffix, split_info in map.items():
936
1262
  mapped_modules = split_info["mapped_modules"]
937
1263
  split_sizes = split_info["split_sizes"]
1264
+ split_share_fields = _merge_share_fields(split_info, share_fields)
1265
+ split_dims = split_info.get("split_dims") or {}
938
1266
  for k, module in modules_dict.items():
939
1267
  if k.endswith("." + module_suffix):
940
1268
  parent_module = modules_dict[k[:len(k)-len(module_suffix)-1]]
941
1269
  weight = module.weight
942
1270
  bias = getattr(module, "bias", None)
943
1271
  if isinstance(module, QModuleMixin):
944
- _data = weight._data
945
- _scale = weight._scale
946
- sub_data = torch.split(_data, split_sizes, dim=0)
947
- sub_scale = torch.split(_scale, split_sizes, dim=0)
948
- sub_bias = torch.split(bias, split_sizes, dim=0)
949
- for sub_name, _subdata, _subbias, _subscale in zip(mapped_modules, sub_data, sub_bias, sub_scale):
1272
+ out_features_total = weight.size(0)
1273
+ if sum(split_sizes) != out_features_total:
1274
+ raise ValueError(
1275
+ f"Split sizes {split_sizes} do not match out_features {out_features_total} for '{k}'."
1276
+ )
1277
+ in_features = weight.size(1)
1278
+ sub_biases = None
1279
+ if bias is not None and bias.dim() > 0 and bias.size(0) == out_features_total:
1280
+ sub_biases = torch.split(bias, split_sizes, dim=0)
1281
+ else:
1282
+ sub_biases = [bias] * len(split_sizes)
1283
+
1284
+ sub_tensors = _get_quantized_subtensors(weight)
1285
+ if not sub_tensors:
1286
+ raise ValueError(f"Unable to split quantized weight for '{k}'.")
1287
+ sub_maps = [dict() for _ in split_sizes]
1288
+ field_tensors = {name: tensor for name, tensor in sub_tensors}
1289
+ base_ctx = {
1290
+ "module": module,
1291
+ "module_name": k,
1292
+ "module_suffix": module_suffix,
1293
+ "mapped_modules": mapped_modules,
1294
+ "split_sizes": split_sizes,
1295
+ "out_features": out_features_total,
1296
+ "in_features": in_features,
1297
+ "field_tensors": field_tensors,
1298
+ "info": split_info,
1299
+ }
1300
+ for name, tensor in sub_tensors:
1301
+ if tensor is None or name in split_share_fields or tensor.dim() <= 1:
1302
+ _fill_sub_maps(sub_maps, name, tensor)
1303
+ continue
1304
+ split_dim = split_dims.get(name)
1305
+ if split_dim is None:
1306
+ if tensor.size(0) == out_features_total:
1307
+ split_dim = 0
1308
+ elif tensor.dim() > 1 and tensor.size(1) == out_features_total:
1309
+ split_dim = 1
1310
+ else:
1311
+ split_dim = 0
1312
+ handler = _get_split_handler(split_info, name, split_handlers)
1313
+ chunks = _call_split_handler(
1314
+ handler,
1315
+ src=tensor,
1316
+ dim=split_dim,
1317
+ split_sizes=split_sizes,
1318
+ context=dict(base_ctx, split_dim=split_dim),
1319
+ )
1320
+ if chunks is None:
1321
+ if tensor.dim() <= split_dim or tensor.size(split_dim) != out_features_total:
1322
+ got_size = "n/a" if tensor.dim() <= split_dim else tensor.size(split_dim)
1323
+ raise ValueError(
1324
+ f"Cannot split '{k}' quantized tensor '{name}': "
1325
+ f"expected size({split_dim})={out_features_total}, got {got_size}."
1326
+ )
1327
+ chunks = torch.split(tensor, split_sizes, dim=split_dim)
1328
+ for sub_map, chunk in zip(sub_maps, chunks):
1329
+ sub_map[name] = chunk
1330
+
1331
+ create_fn = getattr(weight.__class__, "create", None)
1332
+ if not callable(create_fn):
1333
+ raise ValueError(f"Quantized weight class '{weight.__class__.__name__}' has no create()")
1334
+ create_sig = inspect.signature(create_fn)
1335
+ base_kwargs = {
1336
+ "qtype": getattr(weight, "qtype", None),
1337
+ "axis": getattr(weight, "axis", None),
1338
+ "stride": weight.stride(),
1339
+ "dtype": weight.dtype,
1340
+ "activation_qtype": getattr(weight, "activation_qtype", None),
1341
+ "requires_grad": weight.requires_grad,
1342
+ "group_size": getattr(weight, "_group_size", None),
1343
+ "device": weight.device,
1344
+ }
1345
+
1346
+ qmodule_cls = module.__class__
1347
+ for sub_name, sub_size, sub_map, sub_bias in zip(
1348
+ mapped_modules, split_sizes, sub_maps, sub_biases
1349
+ ):
950
1350
  with init_empty_weights():
951
- sub_module = QLinear(_subdata.shape[1], _subdata.shape[0], bias=bias != None, device ="cpu", dtype=weight.dtype)
952
- sub_module.weight = torch.nn.Parameter(WeightQBytesTensor.create(weight.qtype, weight.axis, _subdata.size(), weight.stride(), _subdata, _subscale, activation_qtype=weight.activation_qtype, requires_grad=weight.requires_grad ))
953
- if bias != None:
954
- sub_module.bias = torch.nn.Parameter(_subbias)
1351
+ sub_module = qmodule_cls(
1352
+ in_features,
1353
+ sub_size,
1354
+ bias=bias is not None,
1355
+ device="cpu",
1356
+ dtype=weight.dtype,
1357
+ weights=module.weight_qtype,
1358
+ activations=module.activation_qtype,
1359
+ optimizer=module.optimizer,
1360
+ quantize_input=True,
1361
+ )
1362
+ size = list(weight.size())
1363
+ if size:
1364
+ size[0] = sub_size
1365
+ base_kwargs["size"] = tuple(size)
1366
+ create_kwargs = {}
1367
+ missing = []
1368
+ for name, param in create_sig.parameters.items():
1369
+ if name == "self":
1370
+ continue
1371
+ if name in sub_map:
1372
+ create_kwargs[name] = sub_map[name]
1373
+ elif name in base_kwargs and base_kwargs[name] is not None:
1374
+ create_kwargs[name] = base_kwargs[name]
1375
+ elif param.default is param.empty:
1376
+ missing.append(name)
1377
+ if missing:
1378
+ raise ValueError(
1379
+ f"Unable to rebuild quantized weight for '{k}.{sub_name}': "
1380
+ f"missing {missing}."
1381
+ )
1382
+ sub_weight = create_fn(**create_kwargs)
1383
+ sub_module.weight = torch.nn.Parameter(sub_weight, requires_grad=weight.requires_grad)
1384
+ if sub_bias is not None:
1385
+ sub_module.bias = torch.nn.Parameter(sub_bias)
955
1386
  sub_module.optimizer = module.optimizer
956
1387
  sub_module.weight_qtype = module.weight_qtype
1388
+ sub_module.activation_qtype = module.activation_qtype
957
1389
  setattr(parent_module, sub_name, sub_module)
958
- # del _data, _scale, _subdata, sub_d
959
1390
  else:
960
1391
  sub_data = torch.split(weight, split_sizes, dim=0)
961
- sub_bias = torch.split(bias, split_sizes, dim=0)
1392
+ sub_bias = torch.split(bias, split_sizes, dim=0) if bias is not None else [None] * len(split_sizes)
962
1393
  for sub_name, subdata, subbias in zip(mapped_modules, sub_data, sub_bias):
963
1394
  with init_empty_weights():
964
1395
  sub_module = torch.nn.Linear( subdata.shape[1], subdata.shape[0], bias=bias != None, device ="cpu", dtype=weight.dtype)
@@ -975,7 +1406,39 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
975
1406
 
976
1407
  loras_model_data = getattr(model, "_loras_model_data", None)
977
1408
  if loras_model_data == None:
978
- raise Exception(f"No Loras has been declared for this model while creating the corresponding offload object")
1409
+ merged_loras_model_data = {}
1410
+ merged_loras_shortcuts = {}
1411
+ sub_loras = {}
1412
+ for submodule_name, submodule in model.named_modules():
1413
+ if submodule is model:
1414
+ continue
1415
+ sub_model_data = getattr(submodule, "_loras_model_data", None)
1416
+ if sub_model_data:
1417
+ submodule._lora_owner = model
1418
+ sub_loras[submodule_name] = submodule
1419
+ for k, v in sub_model_data.items():
1420
+ if k not in merged_loras_model_data:
1421
+ merged_loras_model_data[k] = v
1422
+ sub_shortcuts = getattr(submodule, "_loras_model_shortcuts", None)
1423
+ if sub_shortcuts:
1424
+ prefix = f"{submodule_name}." if submodule_name else ""
1425
+ for k, v in sub_shortcuts.items():
1426
+ merged_key = k
1427
+ if prefix:
1428
+ if k:
1429
+ merged_key = f"{prefix}{k}"
1430
+ else:
1431
+ merged_key = submodule_name
1432
+ if merged_key not in merged_loras_shortcuts:
1433
+ merged_loras_shortcuts[merged_key] = v
1434
+ if merged_loras_model_data:
1435
+ model._loras_model_data = merged_loras_model_data
1436
+ if merged_loras_shortcuts:
1437
+ model._loras_model_shortcuts = merged_loras_shortcuts
1438
+ model._subloras = sub_loras
1439
+ loras_model_data = merged_loras_model_data
1440
+ else:
1441
+ raise Exception(f"No Loras has been declared for this model while creating the corresponding offload object")
979
1442
 
980
1443
  if not check_only:
981
1444
  unload_loras_from_model(model)
@@ -1027,7 +1490,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
1027
1490
 
1028
1491
  if split_linear_modules_map != None:
1029
1492
  new_state_dict = dict()
1030
- suffixes = [(".alpha", -2, False), (".lora_B.weight", -3, True), (".lora_A.weight", -3, False)]
1493
+ suffixes = [(".alpha", -2, False), (".lora_B.weight", -3, True), (".lora_A.weight", -3, False), (".lora_up.weight", -3, True), (".lora_down.weight", -3, False),(".dora_scale", -2, False),]
1031
1494
  for module_name, module_data in state_dict.items():
1032
1495
  name_parts = module_name.split(".")
1033
1496
  for suffix, pos, any_split in suffixes:
@@ -1052,22 +1515,25 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
1052
1515
 
1053
1516
  if not fail:
1054
1517
  pos = first_key.find(".")
1055
- prefix = first_key[0:pos]
1056
- if prefix not in ["diffusion_model", "transformer"]:
1057
- msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
1058
- error_msg = append(error_msg, msg)
1059
- fail = True
1060
-
1061
- if not fail:
1518
+ prefix = first_key[0:pos+1]
1519
+ if prefix in ["diffusion_model.", "transformer."]:
1520
+ prefixes = ("diffusion_model.", "transformer.")
1521
+ new_state_dict = {}
1522
+ for k, v in state_dict.items():
1523
+ for candidate in prefixes:
1524
+ if k.startswith(candidate):
1525
+ k = k[len(candidate) :]
1526
+ break
1527
+ new_state_dict[k] = v
1528
+ state_dict = new_state_dict
1062
1529
 
1063
- state_dict = { k[ len(prefix) + 1:]: v for k, v in state_dict.items() if k.startswith(prefix) }
1064
1530
  clean_up = True
1065
1531
 
1066
1532
  keys = list(state_dict.keys())
1067
1533
 
1068
1534
  lora_alphas = {}
1069
1535
  for k in keys:
1070
- if "alpha" in k:
1536
+ if k.endswith(".alpha"):
1071
1537
  alpha_value = state_dict.pop(k)
1072
1538
  if torch.is_tensor(alpha_value):
1073
1539
  alpha_value = float(alpha_value.item())
@@ -1075,17 +1541,19 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
1075
1541
 
1076
1542
  invalid_keys = []
1077
1543
  unexpected_keys = []
1078
- for k, v in state_dict.items():
1079
- lora_A = None
1080
- lora_B = None
1081
- diff_b = None
1082
- diff = None
1544
+ new_state_dict = {}
1545
+ for k in list(state_dict.keys()):
1546
+ v = state_dict.pop(k)
1547
+ lora_A = lora_B = diff_b = diff = lora_key = dora_scale = None
1083
1548
  if k.endswith(".diff"):
1084
1549
  diff = v
1085
1550
  module_name = k[ : -5]
1086
1551
  elif k.endswith(".diff_b"):
1087
1552
  diff_b = v
1088
1553
  module_name = k[ : -7]
1554
+ elif k.endswith(".dora_scale"):
1555
+ dora_scale = v
1556
+ module_name = k[ : -11]
1089
1557
  else:
1090
1558
  pos = k.rfind(".lora_")
1091
1559
  if pos <=0:
@@ -1118,30 +1586,33 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
1118
1586
  if ignore_model_variations:
1119
1587
  skip = True
1120
1588
  else:
1121
- msg = f"Lora '{path}': Lora A dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[1]}, lora A = {v.shape[1]}). It is likely this Lora has been made for another version of this model."
1589
+ msg = f"Lora '{path}/{module_name}': Lora A dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[1]}, lora A = {v.shape[1]}). It is likely this Lora has been made for another version of this model."
1122
1590
  error_msg = append(error_msg, msg)
1123
1591
  fail = True
1124
1592
  break
1593
+ v = lora_A = lora_A.to(module.weight.dtype)
1125
1594
  elif lora_B != None:
1126
1595
  rank = lora_B.shape[1]
1127
1596
  if module_shape[0] != v.shape[0]:
1128
1597
  if ignore_model_variations:
1129
1598
  skip = True
1130
1599
  else:
1131
- msg = f"Lora '{path}': Lora B dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora B = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
1600
+ msg = f"Lora '{path}/{module_name}': Lora B dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora B = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
1132
1601
  error_msg = append(error_msg, msg)
1133
1602
  fail = True
1134
1603
  break
1604
+ v = lora_B = lora_B.to(module.weight.dtype)
1135
1605
  elif diff != None:
1136
1606
  lora_B = diff
1137
1607
  if module_shape != v.shape:
1138
1608
  if ignore_model_variations:
1139
1609
  skip = True
1140
1610
  else:
1141
- msg = f"Lora '{path}': Lora shape is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
1611
+ msg = f"Lora '{path}/{module_name}': Lora shape is not compatible with model '{_get_module_name(model)}' (model = {module_shape}, lora = {v.shape}). It is likely this Lora has been made for another version of this model."
1142
1612
  error_msg = append(error_msg, msg)
1143
1613
  fail = True
1144
1614
  break
1615
+ v = lora_B = lora_B.to(module.weight.dtype)
1145
1616
  elif diff_b != None:
1146
1617
  rank = diff_b.shape[0]
1147
1618
  if not hasattr(module, "bias"):
@@ -1160,26 +1631,42 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
1160
1631
  error_msg = append(error_msg, msg)
1161
1632
  fail = True
1162
1633
  break
1163
-
1634
+ v = diff_b = diff_b.to(module.weight.dtype)
1635
+ elif dora_scale != None:
1636
+ rank = dora_scale.shape[1]
1637
+ if module_shape[0] != v.shape[0]:
1638
+ if ignore_model_variations:
1639
+ skip = True
1640
+ else:
1641
+ msg = f"Lora '{path}': Dora Scale dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, dora scale = {v.shape[0]}). It is likely this Dora has been made for another version of this model."
1642
+ error_msg = append(error_msg, msg)
1643
+ fail = True
1644
+ break
1645
+ v = dora_scale = dora_scale.to(module.weight.dtype)
1164
1646
  if not check_only:
1647
+ new_state_dict[k] = v
1648
+ v = None
1165
1649
  loras_module_data = loras_model_data.get(module, None)
1166
1650
  assert loras_module_data != None
1167
1651
  loras_adapter_data = loras_module_data.get(adapter_name, None)
1168
1652
  if loras_adapter_data == None:
1169
- loras_adapter_data = [None, None, None, 1.]
1653
+ loras_adapter_data = [None, None, None, None, 1.]
1654
+ module.any_dora = False
1170
1655
  loras_module_data[adapter_name] = loras_adapter_data
1171
1656
  if lora_A != None:
1172
- loras_adapter_data[0] = lora_A.to(module.weight.dtype)
1657
+ loras_adapter_data[0] = lora_A
1173
1658
  elif lora_B != None:
1174
- loras_adapter_data[1] = lora_B.to(module.weight.dtype)
1659
+ loras_adapter_data[1] = lora_B
1660
+ elif dora_scale != None:
1661
+ loras_adapter_data[3] = dora_scale
1662
+ loras_module_data["any_dora"] = True
1175
1663
  else:
1176
- loras_adapter_data[2] = diff_b.to(module.weight.dtype)
1177
- if rank != None:
1178
- alpha_key = k[:-len("lora_X.weight")] + "alpha"
1664
+ loras_adapter_data[2] = diff_b
1665
+ if rank != None and lora_key is not None and "lora" in lora_key:
1666
+ alpha_key = k[:-len(lora_key)] + "alpha"
1179
1667
  alpha = lora_alphas.get(alpha_key, None)
1180
- alpha = 1. if alpha == None else alpha / rank
1181
- loras_adapter_data[3] = alpha
1182
- lora_A = lora_B = diff = diff_b = v = loras_module_data = loras_adapter_data = lora_alphas = None
1668
+ if alpha is not None: loras_adapter_data[4] = alpha / rank
1669
+ lora_A = lora_B = diff = diff_b = v = loras_module_data = loras_adapter_data = lora_alphas = dora_scale = None
1183
1670
 
1184
1671
  if len(invalid_keys) > 0:
1185
1672
  msg = f"Lora '{path}' contains non Lora keys '{trunc(invalid_keys,200)}'"
@@ -1202,7 +1689,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
1202
1689
  if not check_only:
1203
1690
  # model._loras_tied_weights[adapter_name] = tied_weights
1204
1691
  if pinnedLora:
1205
- pinned_sd_list.append(state_dict)
1692
+ pinned_sd_list.append(new_state_dict)
1206
1693
  pinned_names_list.append(path)
1207
1694
  # _pin_sd_to_memory(state_dict, path)
1208
1695
 
@@ -1250,6 +1737,7 @@ def sync_models_loras(model, model2):
1250
1737
 
1251
1738
  def unload_loras_from_model(model):
1252
1739
  if model is None: return
1740
+ if not hasattr(model, "_loras_model_data"): return
1253
1741
  for _, v in model._loras_model_data.items():
1254
1742
  v.clear()
1255
1743
  for _, v in model._loras_model_shortcuts.items():
@@ -1264,9 +1752,25 @@ def unload_loras_from_model(model):
1264
1752
 
1265
1753
 
1266
1754
  def set_step_no_for_lora(model, step_no):
1755
+ target = getattr(model, "_lora_owner", None)
1756
+ while target is not None and target is not model:
1757
+ model = target
1758
+ target = getattr(model, "_lora_owner", None)
1267
1759
  model._lora_step_no = step_no
1760
+ sub_loras = getattr(model, "_subloras", None)
1761
+ if sub_loras:
1762
+ submodules = sub_loras.values() if isinstance(sub_loras, dict) else sub_loras
1763
+ for submodule in submodules:
1764
+ if submodule is model:
1765
+ continue
1766
+ submodule._lora_step_no = step_no
1268
1767
 
1269
1768
  def activate_loras(model, lora_nos, lora_multi = None):
1769
+ target = getattr(model, "_lora_owner", None)
1770
+ while target is not None and target is not model:
1771
+ model = target
1772
+ target = getattr(model, "_lora_owner", None)
1773
+
1270
1774
  if not isinstance(lora_nos, list):
1271
1775
  lora_nos = [lora_nos]
1272
1776
  lora_nos = [str(l) for l in lora_nos]
@@ -1281,6 +1785,15 @@ def activate_loras(model, lora_nos, lora_multi = None):
1281
1785
  model._lora_step_no = 0
1282
1786
  model._loras_active_adapters = lora_nos
1283
1787
  model._loras_scaling = lora_scaling_dict
1788
+ sub_loras = getattr(model, "_subloras", None)
1789
+ if sub_loras:
1790
+ submodules = sub_loras.values() if isinstance(sub_loras, dict) else sub_loras
1791
+ for submodule in submodules:
1792
+ if submodule is model:
1793
+ continue
1794
+ submodule._lora_step_no = 0
1795
+ submodule._loras_active_adapters = lora_nos
1796
+ submodule._loras_scaling = lora_scaling_dict
1284
1797
 
1285
1798
 
1286
1799
  def move_loras_to_device(model, device="cpu" ):
@@ -1293,7 +1806,7 @@ def move_loras_to_device(model, device="cpu" ):
1293
1806
  if ".lora_" in k:
1294
1807
  m.to(device)
1295
1808
 
1296
- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, configKwargs ={}):
1809
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, preprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, configKwargs ={}):
1297
1810
  """
1298
1811
  quick version of .LoadfromPretrained of the transformers library
1299
1812
  used to build a model and load the corresponding weights (quantized or not)
@@ -1305,7 +1818,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
1305
1818
  model_path = [model_path]
1306
1819
 
1307
1820
 
1308
- if not builtins.all(file_name.endswith(".sft") or file_name.endswith(".safetensors") or file_name.endswith(".pt") for file_name in model_path):
1821
+ if not builtins.all(file_name.endswith(".sft") or file_name.endswith(".safetensors") or file_name.endswith(".pt") or file_name.endswith(".ckpt") for file_name in model_path):
1309
1822
  raise Exception("full model path to file expected")
1310
1823
 
1311
1824
  model_path = [ _get_model(file) for file in model_path]
@@ -1313,7 +1826,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
1313
1826
  raise Exception("Unable to find file")
1314
1827
 
1315
1828
  verboseLevel = _compute_verbose_level(verboseLevel)
1316
- if model_path[-1].endswith(".pt"):
1829
+ if model_path[-1].endswith(".pt") or model_path[-1].endswith(".ckpt"):
1317
1830
  metadata = None
1318
1831
  else:
1319
1832
  with safetensors2.safe_open(model_path[-1], writable_tensors =writable_tensors) as f:
@@ -1376,18 +1889,18 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantiza
1376
1889
  model = transfomer_class.from_config(transformer_config )
1377
1890
 
1378
1891
 
1379
- torch.set_default_device('cpu')
1380
1892
  model.eval().requires_grad_(False)
1381
1893
 
1382
1894
  model._config = transformer_config
1383
-
1384
- load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, verboseLevel=verboseLevel )
1895
+
1896
+ load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, preprocess_sd = preprocess_sd , modules = modules, return_shared_modules = return_shared_modules, default_dtype = default_dtype, ignore_unused_weights = ignore_unused_weights, verboseLevel=verboseLevel )
1385
1897
 
1386
1898
  return model
1387
1899
 
1388
1900
 
1389
1901
 
1390
- def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, modules = None, return_shared_modules = None, verboseLevel = -1):
1902
+ @cudacontext("cpu")
1903
+ def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, preprocess_sd = None, postprocess_sd = None, modules = None, return_shared_modules = None, default_dtype = torch.bfloat16, ignore_unused_weights = False, verboseLevel = -1, ignore_missing_keys = False):
1391
1904
  """
1392
1905
  Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
1393
1906
  """
@@ -1423,8 +1936,14 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1423
1936
  file_path += modules
1424
1937
  modules = None
1425
1938
 
1426
- file_path = [ _get_model(file) for file in file_path]
1427
- if any( file == None for file in file_path):
1939
+ normalized_paths = []
1940
+ for file in file_path:
1941
+ if isinstance(file, (dict, tuple)):
1942
+ normalized_paths.append(file)
1943
+ else:
1944
+ normalized_paths.append(_get_model(file))
1945
+ file_path = normalized_paths
1946
+ if any(file is None for file in file_path):
1428
1947
  raise Exception("Unable to find file")
1429
1948
  verboseLevel = _compute_verbose_level(verboseLevel)
1430
1949
 
@@ -1442,18 +1961,24 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1442
1961
  for no, file in enumerate(file_path):
1443
1962
  quantization_map = None
1444
1963
  tied_weights_map = None
1445
- if not (".safetensors" in file or ".sft" in file):
1964
+ metadata = None
1965
+ detected_kind = None
1966
+ if isinstance(file, tuple):
1967
+ if len(file) != 2:
1968
+ raise Exception("Expected a tuple of (state_dict, quantization_map)")
1969
+ state_dict, quantization_map = file
1970
+ elif isinstance(file, dict):
1971
+ state_dict = file
1972
+ elif not (".safetensors" in file or ".sft" in file):
1446
1973
  if pinToMemory:
1447
1974
  raise Exception("Pinning to memory while loading only supported for safe tensors files")
1448
- state_dict = torch.load(file, weights_only=True, map_location="cpu")
1975
+ state_dict = torch.load(file, weights_only=False, map_location="cpu")
1449
1976
  if "module" in state_dict:
1450
1977
  state_dict = state_dict["module"]
1451
-
1452
1978
  else:
1453
1979
  basename = os.path.basename(file)
1454
1980
 
1455
1981
  if "-of-" in basename:
1456
- metadata = None
1457
1982
  file_parts= basename.split("-")
1458
1983
  parts_max = int(file_parts[-1][:5])
1459
1984
  state_dict = {}
@@ -1463,29 +1988,50 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1463
1988
  state_dict.update(sd)
1464
1989
  else:
1465
1990
  state_dict, metadata = _safetensors_load_file(file, writable_tensors =writable_tensors)
1466
-
1467
- if metadata != None:
1468
- quantization_map = metadata.get("quantization_map", None)
1469
- config = metadata.get("config", None)
1470
- if config is not None:
1471
- model._config = config
1472
-
1473
- tied_weights_map = metadata.get("tied_weights_map", None)
1474
- if tied_weights_map != None:
1475
- for name, tied_weights_list in tied_weights_map.items():
1476
- mapped_weight = state_dict[name]
1477
- for tied_weights in tied_weights_list:
1478
- state_dict[tied_weights] = mapped_weight
1479
-
1480
- if quantization_map is None:
1481
- pos = str.rfind(file, ".")
1482
- if pos > 0:
1483
- quantization_map_path = file[:pos]
1484
- quantization_map_path += "_map.json"
1485
-
1486
- if os.path.isfile(quantization_map_path):
1487
- with open(quantization_map_path, 'r') as f:
1488
- quantization_map = json.load(f)
1991
+
1992
+ if preprocess_sd != None:
1993
+ state_dict = preprocess_sd(state_dict)
1994
+
1995
+ if metadata != None:
1996
+ quantization_map = metadata.get("quantization_map", None)
1997
+ config = metadata.get("config", None)
1998
+ if config is not None:
1999
+ model._config = config
2000
+
2001
+ tied_weights_map = metadata.get("tied_weights_map", None)
2002
+ if tied_weights_map != None:
2003
+ for name, tied_weights_list in tied_weights_map.items():
2004
+ mapped_weight = state_dict[name]
2005
+ for tied_weights in tied_weights_list:
2006
+ state_dict[tied_weights] = mapped_weight
2007
+
2008
+ if quantization_map is None and isinstance(file, str):
2009
+ pos = str.rfind(file, ".")
2010
+ if pos > 0:
2011
+ quantization_map_path = file[:pos]
2012
+ quantization_map_path += "_map.json"
2013
+
2014
+ if os.path.isfile(quantization_map_path):
2015
+ with open(quantization_map_path, 'r') as f:
2016
+ quantization_map = json.load(f)
2017
+
2018
+ if quantization_map is None:
2019
+ conv_result = detect_and_convert(state_dict, default_dtype=default_dtype, verboseLevel=verboseLevel)
2020
+ detected_kind = conv_result.get("kind")
2021
+ if conv_result.get("kind") not in ("none", "quanto"):
2022
+ state_dict = conv_result["state_dict"]
2023
+ quantization_map = conv_result["quant_map"]
2024
+ conv_result = None
2025
+ # enable_fp8_fp32_scale_support()
2026
+
2027
+ if detected_kind in (None, "none") and isinstance(file, str) and (".safetensors" in file or ".sft" in file):
2028
+ try:
2029
+ info = detect_safetensors_format(state_dict, verboseLevel=verboseLevel)
2030
+ detected_kind = info.get("kind")
2031
+ except Exception:
2032
+ detected_kind = detected_kind or None
2033
+ if detected_kind not in (None, "none") and isinstance(file, str):
2034
+ cache_quantization_for_file(file, detected_kind or "none")
1489
2035
 
1490
2036
  full_state_dict.update(state_dict)
1491
2037
  if quantization_map != None:
@@ -1504,8 +2050,8 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1504
2050
  full_state_dict, full_quantization_map, full_tied_weights_map = None, None, None
1505
2051
 
1506
2052
  # deal if we are trying to load just a sub part of a larger model
1507
- if preprocess_sd != None:
1508
- state_dict, quantization_map = preprocess_sd(state_dict, quantization_map)
2053
+ if postprocess_sd != None:
2054
+ state_dict, quantization_map = postprocess_sd(state_dict, quantization_map)
1509
2055
 
1510
2056
  if modelPrefix != None:
1511
2057
  base_model_prefix = modelPrefix + "."
@@ -1513,11 +2059,21 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1513
2059
  if quantization_map != None:
1514
2060
  quantization_map = filter_state_dict(quantization_map,base_model_prefix)
1515
2061
 
2062
+ post_load_hooks = []
2063
+ if quantization_map:
2064
+ quantization_map, post_load_hooks = apply_pre_quantization(
2065
+ model,
2066
+ state_dict,
2067
+ quantization_map,
2068
+ default_dtype=default_dtype,
2069
+ verboseLevel=verboseLevel,
2070
+ )
2071
+
1516
2072
  if len(quantization_map) == 0:
1517
- if any("quanto" in file for file in file_path) and not do_quantize:
2073
+ if any(isinstance(file, str) and "quanto" in file for file in file_path) and not do_quantize:
1518
2074
  print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
1519
2075
  else:
1520
- _requantize(model, state_dict, quantization_map)
2076
+ _requantize(model, state_dict, quantization_map, default_dtype=default_dtype)
1521
2077
 
1522
2078
 
1523
2079
 
@@ -1530,13 +2086,25 @@ def load_model_data(model, file_path, do_quantize = False, quantizationType = qi
1530
2086
  base_model_prefix = k[:-len(missing_keys[0])]
1531
2087
  break
1532
2088
  if base_model_prefix == None:
1533
- raise Exception(f"Missing keys: {missing_keys}")
1534
- state_dict = filter_state_dict(state_dict, base_model_prefix)
1535
- missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
2089
+ if not ignore_missing_keys:
2090
+ raise Exception(f"Missing keys: {missing_keys}")
2091
+ else:
2092
+ state_dict = filter_state_dict(state_dict, base_model_prefix)
2093
+ missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
2094
+ if len(missing_keys) > 0 and not ignore_missing_keys:
2095
+ raise Exception(f"Missing keys: {missing_keys}")
1536
2096
 
1537
2097
  del state_dict
1538
2098
 
1539
- if len(unexpected_keys) > 0 and verboseLevel >=2:
2099
+ if post_load_hooks:
2100
+ for hook in post_load_hooks:
2101
+ try:
2102
+ hook(model)
2103
+ except Exception as e:
2104
+ if verboseLevel >= 2:
2105
+ print(f"Post-load hook skipped: {e}")
2106
+
2107
+ if len(unexpected_keys) > 0 and verboseLevel >=2 and not ignore_unused_weights:
1540
2108
  print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")
1541
2109
 
1542
2110
  for k,p in model.named_parameters():
@@ -1728,6 +2296,35 @@ class HfHook:
1728
2296
  def detach_hook(self, module):
1729
2297
  return module
1730
2298
 
2299
+ def _mm_lora_linear_forward(module, *args, **kwargs):
2300
+ loras_data = getattr(module, "_mm_lora_data", None)
2301
+ if not loras_data:
2302
+ return module._mm_lora_old_forward(*args, **kwargs)
2303
+ if not hasattr(module, "_mm_manager"):
2304
+ pass
2305
+ return module._mm_manager._lora_linear_forward(
2306
+ module._mm_lora_model,
2307
+ module,
2308
+ loras_data,
2309
+ *args,
2310
+ **kwargs,
2311
+ )
2312
+
2313
+
2314
+ def _mm_lora_generic_forward(module, *args, **kwargs):
2315
+ loras_data = getattr(module, "_mm_lora_data", None)
2316
+ if not loras_data:
2317
+ return module._mm_lora_old_forward(*args, **kwargs)
2318
+ return module._mm_manager._lora_generic_forward(
2319
+ module._mm_lora_model,
2320
+ module,
2321
+ loras_data,
2322
+ module._mm_lora_old_forward,
2323
+ *args,
2324
+ **kwargs,
2325
+ )
2326
+
2327
+
1731
2328
  last_offload_obj = None
1732
2329
  class offload:
1733
2330
  def __init__(self):
@@ -1757,6 +2354,7 @@ class offload:
1757
2354
  global last_offload_obj
1758
2355
  last_offload_obj = self
1759
2356
 
2357
+ self._type_wrappers = {}
1760
2358
 
1761
2359
  def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name, submodule_name):
1762
2360
 
@@ -1781,22 +2379,12 @@ class offload:
1781
2379
  param_size = 0
1782
2380
  ref = _get_tensor_ref(p)
1783
2381
  tied_param = self.parameters_ref.get(ref, None)
1784
- if isinstance(p, QTensor):
1785
- blocks_params.append( (submodule, k, p, False, tied_param ) )
1786
-
1787
- if p._qtype == qint4:
1788
- if hasattr(p,"_scale_shift"):
1789
- param_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
1790
- param_size += torch.numel(p._data._data) * p._data._data.element_size()
1791
- else:
1792
- param_size += torch.numel(p._scale) * p._scale.element_size()
1793
- param_size += torch.numel(p._shift) * p._shift.element_size()
1794
- param_size += torch.numel(p._data._data) * p._data._data.element_size()
1795
- else:
1796
- param_size += torch.numel(p._scale) * p._scale.element_size()
1797
- param_size += torch.numel(p._data) * p._data.element_size()
2382
+ blocks_params.append((submodule, k, p, False, tied_param))
2383
+ sub_tensors = _get_quantized_subtensors(p)
2384
+ if sub_tensors:
2385
+ param_size += _subtensors_nbytes(sub_tensors)
2386
+ del sub_tensors
1798
2387
  else:
1799
- blocks_params.append( (submodule, k, p, False, tied_param) )
1800
2388
  param_size += torch.numel(p.data) * p.data.element_size()
1801
2389
 
1802
2390
 
@@ -2091,7 +2679,7 @@ class offload:
2091
2679
  data = loras_data.get(active_adapter + '_GPU', None)
2092
2680
  if data == None:
2093
2681
  continue
2094
- diff_w , _ , diff_b, alpha = data
2682
+ diff_w , _ , diff_b, _, alpha = data
2095
2683
  scaling = self._get_lora_scaling( loras_scaling, model, active_adapter) * alpha
2096
2684
  if scaling == 0:
2097
2685
  continue
@@ -2117,15 +2705,117 @@ class offload:
2117
2705
  return ret
2118
2706
 
2119
2707
 
2708
+ def _dora_linear_forward(
2709
+ self,
2710
+ model,
2711
+ submodule,
2712
+ adapters_data, # dict: name+"_GPU" -> (A, B, diff_b, g_abs, alpha); g_abs=None means LoRA
2713
+ weight= None,
2714
+ bias = None,
2715
+ original_bias = True,
2716
+ dora_mode: str = "blend", # "ref_exact" | "blend"
2717
+ ):
2718
+ active_adapters = getattr(model, "_loras_active_adapters", [])
2719
+ loras_scaling = getattr(model, "_loras_scaling", {})
2720
+ # Snapshot base weight (safe for quantized modules)
2721
+ if weight is None:
2722
+ bias = submodule.bias
2723
+ original_bias = True
2724
+ if isinstance(submodule, QModuleMixin):
2725
+ weight = submodule.weight.view(submodule.weight.shape)
2726
+ else:
2727
+ weight = submodule.weight.clone()
2728
+
2729
+ base_dtype = weight.dtype
2730
+ eps = 1e-8
2731
+ W0 = weight.float()
2732
+ g0 = torch.linalg.vector_norm(W0, dim=1, keepdim=True, dtype=torch.float32).clamp_min(eps) # [out,1]
2733
+
2734
+ # Keep big mats in low precision
2735
+ # Wc = W0 if W0.dtype == compute_dtype else W0.to(compute_dtype)
2736
+ W0 /= g0
2737
+ weight[...] = W0.to(base_dtype)
2738
+ W0 = None
2739
+
2740
+ dir_update = None # Σ s * ((B@A)/g0), accumulated in the adapters' dtype
2741
+ g = None # final magnitude: set absolute (ref_exact) or blended (blend)
2742
+ bias_delta = None # Σ s * diff_b
2743
+
2744
+ # Accumulate DoRA adapters only (g_abs != None)
2745
+ for name in active_adapters:
2746
+ data = adapters_data.get(name + "_GPU", None)
2747
+ if data is None: continue
2748
+ A, B, diff_b, g_abs, alpha = data
2749
+ if g_abs is None: continue
2750
+
2751
+ s = self._get_lora_scaling(loras_scaling, model, name) * float(alpha)
2752
+ if s == 0: continue
2753
+
2754
+ # Direction update in V-space with row-wise 1/g0
2755
+ if (A is not None) and (B is not None):
2756
+ dV = torch.mm(B, A) # [out, in]
2757
+ dV /= g0 # row-wise divide
2758
+ dV.mul_(s)
2759
+ dir_update = dV if dir_update is None else dir_update.add_(dV)
2760
+
2761
+
2762
+ if dora_mode == "ref_exact":
2763
+ # absolute magnitude (last one wins if multiple DoRAs present)
2764
+ g = g_abs
2765
+ elif dora_mode == "blend":
2766
+ # blend towards absolute magnitude proportional to s
2767
+ if g is None:
2768
+ g = g0.clone()
2769
+ g.add_(g_abs.sub(g0), alpha=s)
2770
+ else:
2771
+ raise ValueError(f"Unknown dora_mode: {dora_mode}")
2772
+
2773
+ # Optional bias deltas (not in reference, but harmless if present)
2774
+ if diff_b is not None:
2775
+ db = diff_b.mul(s)
2776
+ bias_delta = db if bias_delta is None else bias_delta.add_(db)
2777
+ db = None
2778
+
2779
+ if g is None:
2780
+ g = g0 # no magnitude provided -> keep original
2781
+
2782
+ # Re-normalize rows if we changed direction
2783
+ if dir_update is not None:
2784
+ weight.add_(dir_update)
2785
+ V = weight.float()
2786
+ Vn = torch.linalg.vector_norm(V, dim=1, keepdim=True, dtype=torch.float32).clamp_min(eps)
2787
+ V /= Vn
2788
+ V *= g
2789
+ weight[...] = V.to(base_dtype)
2790
+ V = None
2791
+ else:
2792
+ weight *= g
2793
+ # Recompose adapted weight; cast back to module dtype
2794
+
2795
+ # Merge DoRA bias delta safely
2796
+ if bias_delta is not None:
2797
+ if bias is None:
2798
+ bias = bias_delta
2799
+ else:
2800
+ bias = bias.clone() if original_bias else bias
2801
+ bias.add_(bias_delta)
2802
+
2803
+ return weight, bias
2804
+
2805
+
2806
+
2120
2807
  def _lora_linear_forward(self, model, submodule, loras_data, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
2121
2808
  weight = submodule.weight
2809
+ bias = submodule.bias
2122
2810
  active_adapters = model._loras_active_adapters
2123
2811
  loras_scaling = model._loras_scaling
2812
+ any_dora = loras_data.get("any_dora", False)
2813
+ is_nvfp4 = getattr(submodule, "is_nvfp4", False)
2124
2814
  training = False
2125
2815
 
2126
- dtype = weight.dtype
2127
- if weight.shape[-1] < x.shape[-2]: # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
2128
- bias = submodule.bias
2816
+ dtype = weight.dtype
2817
+ if (weight.shape[-1] < x.shape[-2] and False or any_dora): # merge the base weight and LoRA matrices instead of applying the input to each LoRA sub-matrix when the input is large; this saves a lot of VRAM and compute (the size heuristic is currently disabled, so this path is only taken when a DoRA adapter is active)
2818
+ original_bias = True
2129
2819
  original_bias = True
2130
2820
  if len(active_adapters) > 0:
2131
2821
  if isinstance(submodule, QModuleMixin):
@@ -2136,10 +2826,17 @@ class offload:
2136
2826
  data = loras_data.get(active_adapter + '_GPU', None)
2137
2827
  if data == None:
2138
2828
  continue
2139
- lora_A_weight, lora_B_weight, diff_b, alpha = data
2829
+ lora_A_weight, lora_B_weight, diff_b, g_abs, alpha = data
2140
2830
  scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
2141
- if scaling == 0:
2831
+ if scaling == 0 or g_abs is not None:
2142
2832
  continue
2833
+ target_dtype = weight.dtype
2834
+ if lora_A_weight is not None and lora_A_weight.dtype != target_dtype:
2835
+ lora_A_weight = lora_A_weight.to(target_dtype)
2836
+ if lora_B_weight is not None and lora_B_weight.dtype != target_dtype:
2837
+ lora_B_weight = lora_B_weight.to(target_dtype)
2838
+ if diff_b is not None and diff_b.dtype != target_dtype:
2839
+ diff_b = diff_b.to(target_dtype)
2143
2840
  if lora_A_weight != None:
2144
2841
  weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )
2145
2842
 
@@ -2152,41 +2849,60 @@ class offload:
2152
2849
  original_bias = False
2153
2850
  bias.add_(diff_b, alpha=scaling)
2154
2851
  # base_weight += scaling * lora_B_weight @ lora_A_weight
2852
+
2853
+ if any_dora :
2854
+ weight, bias = self._dora_linear_forward(model, submodule, loras_data, weight, bias, original_bias)
2155
2855
  if training:
2156
2856
  pass
2157
2857
  # result = torch.nn.functional.linear(dropout(x), base_weight, bias=submodule.bias)
2158
2858
  else:
2159
- result = torch.nn.functional.linear(x, weight, bias=bias)
2859
+ base_bias = bias
2860
+ if base_bias is not None and base_bias.dtype != x.dtype:
2861
+ base_bias = base_bias.to(x.dtype)
2862
+ result = torch.nn.functional.linear(x, weight, bias=base_bias)
2160
2863
 
2161
2864
  else:
2162
- result = torch.nn.functional.linear(x, weight, bias=submodule.bias)
2865
+ base_bias = bias
2866
+ if base_bias is not None and base_bias.dtype != x.dtype:
2867
+ base_bias = base_bias.to(x.dtype)
2868
+ result = torch.nn.functional.linear(x, weight, bias=base_bias)
2163
2869
 
2164
2870
  if len(active_adapters) > 0:
2165
- x = x.to(dtype)
2871
+ compute_dtype = torch.float32 if is_nvfp4 else result.dtype
2872
+ if result.dtype != compute_dtype:
2873
+ result = result.to(compute_dtype)
2874
+ x = x.to(compute_dtype)
2166
2875
 
2167
2876
  for active_adapter in active_adapters:
2168
2877
  data = loras_data.get(active_adapter + '_GPU', None)
2169
2878
  if data == None:
2170
2879
  continue
2171
- lora_A, lora_B, diff_b, alpha = data
2880
+ lora_A, lora_B, diff_b, g_abs, alpha = data
2172
2881
  # dropout = self.lora_dropout[active_adapter]
2173
2882
  scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
2174
- if scaling == 0:
2883
+ if scaling == 0 or g_abs is not None:
2175
2884
  continue
2885
+ target_dtype = result.dtype
2886
+ if lora_A is not None and lora_A.dtype != target_dtype:
2887
+ lora_A = lora_A.to(target_dtype)
2888
+ if lora_B is not None and lora_B.dtype != target_dtype:
2889
+ lora_B = lora_B.to(target_dtype)
2890
+ if diff_b is not None and diff_b.dtype != target_dtype:
2891
+ diff_b = diff_b.to(target_dtype)
2892
+
2176
2893
  if lora_A == None:
2177
2894
  result.add_(diff_b, alpha=scaling)
2178
2895
  else:
2179
- x = x.to(lora_A.dtype)
2180
-
2181
- if training:
2182
- pass
2183
- # y = lora_A(dropout(x))
2184
- else:
2185
- y = torch.nn.functional.linear(x, lora_A, bias=None)
2186
- y = torch.nn.functional.linear(y, lora_B, bias=diff_b)
2187
- y*= scaling
2188
- result+= y
2896
+ x_2d = x.reshape(-1, x.shape[-1])
2897
+ result_2d = result.reshape(-1, result.shape[-1])
2898
+ y = x_2d @ lora_A.T
2899
+ result_2d.addmm_(y, lora_B.T, beta=1, alpha=scaling)
2900
+ if diff_b is not None:
2901
+ result_2d.add_(diff_b, alpha=scaling)
2189
2902
  del y
2903
+ target_dtype = input_dtype if is_nvfp4 else dtype
2904
+ if result.dtype != target_dtype:
2905
+ result = result.to(target_dtype)
2190
2906
 
2191
2907
  return result
2192
2908
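As a reference for the two adapter paths implemented above: plain LoRA adds the low-rank update on the activations, while the DoRA branch (_dora_linear_forward) merges the update into the weight, renormalizes each row and rescales it by a blended magnitude. A standalone sketch of that math, with toy tensors and local names (illustration only, not the module's API):

import torch

def lora_output(x, W0, bias, A, B, scaling):
    # plain LoRA: y = x @ W0^T + bias + scaling * (x @ A^T) @ B^T
    y = torch.nn.functional.linear(x, W0, bias)
    return y + scaling * (x @ A.T) @ B.T

def dora_weight(W0, A, B, g_abs, scaling, eps=1e-8):
    # DoRA ("blend" mode above): merge the low-rank update into the direction,
    # renormalize each row, then rescale to the blended magnitude g.
    g0 = torch.linalg.vector_norm(W0.float(), dim=1, keepdim=True).clamp_min(eps)
    V = W0.float() / g0 + scaling * (B.float() @ A.float()) / g0
    g = g0 + scaling * (g_abs.float() - g0)
    Vn = torch.linalg.vector_norm(V, dim=1, keepdim=True).clamp_min(eps)
    return (V / Vn * g).to(W0.dtype)

# toy shapes: out=4, in=6, rank=2
W0 = torch.randn(4, 6); A = torch.randn(2, 6); B = torch.randn(4, 2)
x = torch.randn(3, 6)
print(lora_output(x, W0, None, A, B, scaling=0.5).shape)                  # torch.Size([3, 4])
print(dora_weight(W0, A, B, g_abs=torch.ones(4, 1), scaling=0.5).shape)   # torch.Size([4, 6])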
 
@@ -2198,22 +2914,14 @@ class offload:
2198
2914
  assert submodule_name not in loras_model_shortcuts
2199
2915
  loras_model_shortcuts[submodule_name] = loras_data
2200
2916
  loras_model_data[submodule] = loras_data
2917
+ submodule._mm_lora_data = loras_data
2918
+ submodule._mm_lora_model = current_model
2919
+ submodule._mm_lora_old_forward = old_forward
2201
2920
 
2202
- if isinstance(submodule, torch.nn.Linear):
2203
- def lora_linear_forward(module, *args, **kwargs):
2204
- if len(loras_data) == 0:
2205
- return old_forward(*args, **kwargs)
2206
- else:
2207
- submodule.aaa = submodule_name
2208
- return self._lora_linear_forward(current_model, submodule, loras_data, *args, **kwargs)
2209
- target_fn = lora_linear_forward
2921
+ if isinstance(submodule, torch.nn.Linear) or getattr(submodule, "is_nvfp4", False):
2922
+ target_fn = _mm_lora_linear_forward
2210
2923
  else:
2211
- def lora_generic_forward(module, *args, **kwargs):
2212
- if len(loras_data) == 0:
2213
- return old_forward(*args, **kwargs)
2214
- else:
2215
- return self._lora_generic_forward(current_model, submodule, loras_data, old_forward, *args, **kwargs)
2216
- target_fn = lora_generic_forward
2924
+ target_fn = _mm_lora_generic_forward
2217
2925
  return functools.update_wrapper(functools.partial(target_fn, submodule), old_forward)
2218
2926
 
2219
2927
  def ensure_model_loaded(self, model_id):
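The dispatchers registered above keep all per-module state (_mm_lora_data, _mm_lora_model, _mm_lora_old_forward) as attributes on the module itself, so the same module-level function can be bound to many instances with functools.partial, while update_wrapper preserves the original forward's metadata. A small generic sketch of that rebinding pattern (plain PyTorch, not the mmgp internals):

import functools
import torch

def traced_forward(module, *args, **kwargs):
    # runs custom logic, then delegates to the original forward saved on the module
    print(f"forward of {module.__class__.__name__} intercepted")
    return module._old_forward(*args, **kwargs)

layer = torch.nn.Linear(4, 4)
layer._old_forward = layer.forward
# update_wrapper keeps the original forward's name/docstring for introspection
layer.forward = functools.update_wrapper(functools.partial(traced_forward, layer), layer._old_forward)

out = layer(torch.randn(2, 4))   # prints the trace line, then runs the real forward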
@@ -2236,10 +2944,65 @@ class offload:
2236
2944
 
2237
2945
  # needs to be registered before the forward so as not to break the efficiency of the compilation chain
2238
2946
  # it should sit at the top of the compilation chain, as this type of hook in the middle of a chain seems to hurt memory performance
2239
- target_module.register_forward_pre_hook(preload_blocks_for_compile)
2947
+ target_module.register_forward_pre_hook(preload_blocks_for_compile)
2948
+
2240
2949
 
2241
2950
 
2242
- def hook_check_empty_cache_needed(self, target_module, model, model_id, blocks_name, previous_method, context):
2951
+
2952
+ @torch._dynamo.disable
2953
+ def _pre_check(self, module):
2954
+ model_id = getattr(module, "_mm_model_id", None)
2955
+ blocks_name = getattr(module, "_mm_blocks_name", None)
2956
+
2957
+ self.ensure_model_loaded(model_id)
2958
+ if blocks_name is None:
2959
+ if self.ready_to_check_mem():
2960
+ self.empty_cache_if_needed()
2961
+ elif blocks_name != self.loaded_blocks[model_id] and \
2962
+ blocks_name not in self.preloaded_blocks_per_model[model_id]:
2963
+ self.gpu_load_blocks(model_id, blocks_name)
2964
+
2965
+ def _get_wrapper_for_type(self, mod_cls):
2966
+ fn = self._type_wrappers.get(mod_cls)
2967
+ if fn is not None:
2968
+ return fn
2969
+
2970
+ # Unique function name per class -> unique compiled code object
2971
+ fname = f"_mm_wrap_{mod_cls.__module__.replace('.', '_')}_{mod_cls.__name__}"
2972
+
2973
+ # Keep body minimal; all heavy/offload logic runs out-of-graph in _pre_check
2974
+ # Include __TYPE_CONST in the code so the bytecode/consts differ per class.
2975
+ src = f"""
2976
+ def {fname}(module, *args, **kwargs):
2977
+ _ = __TYPE_CONST # anchor type as a constant to make code object unique per class
2978
+ nada = "{fname}"
2979
+ mgr = module._mm_manager
2980
+ mgr._pre_check(module)
2981
+ return module._mm_forward(*args, **kwargs) #{fname}
2982
+ """
2983
+ ns = {"__TYPE_CONST": mod_cls}
2984
+ exec(src, ns) # compile a new function object/code object for this class
2985
+ fn = ns[fname]
2986
+ self._type_wrappers[mod_cls] = fn
2987
+ return fn
2988
+
2989
+ def hook_check_load_into_GPU_if_needed(
2990
+ self, target_module, model, model_id, blocks_name, previous_method, context
2991
+ ):
2992
+ # store instance data on the module (not captured by the wrapper)
2993
+ target_module._mm_manager = self
2994
+ target_module._mm_model_id = model_id
2995
+ target_module._mm_blocks_name = blocks_name
2996
+ target_module._mm_forward = previous_method
2997
+
2998
+ # per-TYPE wrapper (unique bytecode per class, reused across instances of that class)
2999
+ wrapper_fn = self._get_wrapper_for_type(type(target_module))
3000
+
3001
+ # bind the per-type wrapper to this module instance (the MethodType alternative is kept commented out below)
3002
+ # target_module.forward = types.MethodType(wrapper_fn, target_module)
3003
+ target_module.forward = functools.update_wrapper(functools.partial(wrapper_fn, target_module), previous_method)
3004
+
3005
+ def hook_check_load_into_GPU_if_needed_default(self, target_module, model, model_id, blocks_name, previous_method, context):
2243
3006
 
2244
3007
  dtype = model._dtype
2245
3008
  qint4quantization = isinstance(target_module, QModuleMixin) and target_module.weight!= None and target_module.weight.qtype == qint4
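hook_check_load_into_GPU_if_needed above splits the work in two: the offload bookkeeping runs out-of-graph in _pre_check (decorated with @torch._dynamo.disable), and the thin wrapper that calls it is generated once per module class via exec, so each class gets its own code object instead of every module sharing one closure, which keeps torch.compile guards and caches separate per class. A condensed sketch of the per-class generation trick (illustrative only; the real generated body also calls the manager's pre-check):

import torch

_wrapper_cache = {}

def make_wrapper_for_type(mod_cls):
    # one generated function (hence one code object) per module class
    fn = _wrapper_cache.get(mod_cls)
    if fn is not None:
        return fn
    fname = f"_wrap_{mod_cls.__name__}"
    src = (
        f"def {fname}(module, *args, **kwargs):\n"
        f"    _ = __TYPE_CONST  # per-class reference keeps the generated code distinct\n"
        f"    return module._orig_forward(*args, **kwargs)\n"
    )
    ns = {"__TYPE_CONST": mod_cls}
    exec(src, ns)   # compile a fresh function object for this class
    fn = _wrapper_cache[mod_cls] = ns[fname]
    return fn

print(make_wrapper_for_type(torch.nn.Linear).__code__
      is make_wrapper_for_type(torch.nn.LayerNorm).__code__)   # False: separate code objects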
@@ -2259,22 +3022,33 @@ class offload:
2259
3022
  target_module.forward = target_module._mm_forward
2260
3023
  return
2261
3024
 
2262
- def check_empty_cuda_cache(module, *args, **kwargs):
3025
+ def check_load_into_GPU_needed():
2263
3026
  self.ensure_model_loaded(model_id)
2264
3027
  if blocks_name == None:
2265
3028
  if self.ready_to_check_mem():
2266
3029
  self.empty_cache_if_needed()
2267
3030
  elif blocks_name != self.loaded_blocks[model_id] and blocks_name not in self.preloaded_blocks_per_model[model_id]:
2268
3031
  self.gpu_load_blocks(model_id, blocks_name)
2269
- if qint4quantization and dtype !=None:
2270
- args, kwargs = self.move_args_to_gpu(dtype, *args, **kwargs)
2271
-
2272
- return previous_method(*args, **kwargs)
3032
+ # if qint4quantization and dtype !=None:
3033
+ # args, kwargs = self.move_args_to_gpu(dtype, *args, **kwargs)
3034
+
3035
+ if isinstance(target_module, torch.nn.Linear):
3036
+ def check_load_into_GPU_needed_linear(module, *args, **kwargs):
3037
+ check_load_into_GPU_needed()
3038
+ return previous_method(*args, **kwargs) # linear
3039
+ check_load_into_GPU_needed_module = check_load_into_GPU_needed_linear
3040
+ else:
3041
+ def check_load_into_GPU_needed_other(module, *args, **kwargs):
3042
+ check_load_into_GPU_needed()
3043
+ return previous_method(*args, **kwargs) # other
3044
+ check_load_into_GPU_needed_module = check_load_into_GPU_needed_other
2273
3045
 
2274
3046
  setattr(target_module, "_mm_id", model_id)
3047
+ setattr(target_module, "_mm_manager", self)
2275
3048
  setattr(target_module, "_mm_forward", previous_method)
2276
3049
 
2277
- setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
3050
+ setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_load_into_GPU_needed_module, target_module), previous_method) )
3051
+ # target_module.register_forward_pre_hook(check_empty_cuda_cache)
2278
3052
 
2279
3053
 
2280
3054
  def hook_change_module(self, target_module, model, model_id, module_id, previous_method, previous_method_name ):
@@ -2300,7 +3074,7 @@ class offload:
2300
3074
  if not self.verboseLevel >=1:
2301
3075
  return
2302
3076
 
2303
- if module_id == None or module_id =='':
3077
+ if previous_method_name =="forward" and (module_id == None or module_id ==''):
2304
3078
  model_name = model._get_name()
2305
3079
  print(f"Hooked to model '{model_id}' ({model_name})")
2306
3080
 
@@ -2415,7 +3189,7 @@ class offload:
2415
3189
 
2416
3190
 
2417
3191
 
2418
- def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, partialPinning = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, convertWeightsFloatTo = torch.bfloat16, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
3192
+ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, partialPinning = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, convertWeightsFloatTo = torch.bfloat16, perc_reserved_mem_max = 0, coTenantsMap = None, vram_safety_coefficient = 0.8, compile_mode ="default", verboseLevel = -1):
2419
3193
  """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
2420
3194
  pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
2421
3195
  quantizeTransformer: True by default; quantizes the video / image model on the fly
@@ -2424,6 +3198,8 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2424
3198
  budgets: 0 by default (unlimited). If non-zero, it corresponds to the maximum size in MB that each model may occupy at any moment
2425
3199
  (in fact the real usage is twice this number). It is very effective at reducing VRAM consumption, but this feature may be very slow
2426
3200
  if pinnedMemory is not enabled
3201
+ vram_safety_coefficient: float between 0 and 1 (exclusive), default 0.8. Sets the maximum portion of VRAM that can be used for models.
3202
+ Lower values provide more safety margin but may reduce performance.
2427
3203
  """
2428
3204
  self = offload()
2429
3205
  self.verboseLevel = verboseLevel
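A typical call with the two new keyword arguments might look as follows; 'pipe' stands for an already-constructed pipeline and the numeric values are placeholders, not recommendations taken from this diff:

from mmgp import offload

offload.all(
    pipe,                            # pipeline object or dict of modules
    pinnedMemory=True,
    budgets=3000,                    # per-model cap in MB (0 = unlimited)
    compile=True,
    compile_mode="default",          # new: forwarded to torch.compile(mode=...)
    vram_safety_coefficient=0.7,     # new: cap models at 70% of VRAM instead of the 80% default
)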
@@ -2439,7 +3215,11 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2439
3215
  return float(b[:-1]) * self.device_mem_capacity
2440
3216
  else:
2441
3217
  return b * ONE_MB
2442
-
3218
+
3219
+ # Validate vram_safety_coefficient
3220
+ if not isinstance(vram_safety_coefficient, float) or vram_safety_coefficient <= 0 or vram_safety_coefficient >= 1:
3221
+ raise ValueError("vram_safety_coefficient must be a float between 0 and 1 (exclusive)")
3222
+
2443
3223
  budget = 0
2444
3224
  if not budgets is None:
2445
3225
  if isinstance(budgets , dict):
@@ -2523,26 +3303,22 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2523
3303
 
2524
3304
  current_model_size = 0
2525
3305
  model_dtype = getattr(current_model, "_model_dtype", None)
3306
+ fp8_fallback_dtype = None
2526
3307
  # if model_dtype == None:
2527
3308
  # model_dtype = getattr(current_model, "dtype", None)
2528
3309
  for _ , m in current_model.named_modules():
2529
3310
  ignore_dtype = hasattr(m, "_lock_dtype")
2530
3311
  for n, p in m.named_parameters(recurse = False):
2531
3312
  p.requires_grad = False
2532
- if isinstance(p, QTensor):
2533
- if p._qtype == qint4:
2534
- if hasattr(p,"_scale_shift"):
2535
- current_model_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
2536
- else:
2537
- current_model_size += torch.numel(p._scale) * p._shift.element_size() + torch.numel(p._scale) * p._shift.element_size()
2538
-
2539
- current_model_size += torch.numel(p._data._data) * p._data._data.element_size()
2540
-
2541
- else:
2542
- current_model_size += torch.numel(p._scale) * p._scale.element_size()
2543
- current_model_size += torch.numel(p._data) * p._data.element_size()
2544
- dtype = p._scale.dtype
2545
-
3313
+ sub_tensors = _get_quantized_subtensors(p)
3314
+ if sub_tensors:
3315
+ current_model_size += _subtensors_nbytes(sub_tensors)
3316
+ dtype = sub_tensors[0][1].dtype
3317
+ for name, tensor in sub_tensors:
3318
+ if name in ("scale", "scale_shift"):
3319
+ dtype = tensor.dtype
3320
+ break
3321
+ del sub_tensors
2546
3322
  else:
2547
3323
  if not ignore_dtype:
2548
3324
  dtype = p.data.dtype
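The size accounting above now goes through _get_quantized_subtensors and _subtensors_nbytes, whose definitions are outside this hunk. Judging from how they are used here, the byte count is simply the sum of the raw sizes of the quantized sub-tensors (data, scale, shift or scale_shift). A sketch of that assumed contract, with hypothetical names:

import torch

def subtensors_nbytes(sub_tensors):
    # sub_tensors is a list of (name, tensor) pairs, e.g. ("data", ...), ("scale", ...),
    # ("scale_shift", ...); the footprint is just the sum of their raw sizes in bytes.
    return sum(t.numel() * t.element_size() for _, t in sub_tensors)

example = [("data", torch.zeros(1024, dtype=torch.int8)),
           ("scale", torch.zeros(8, dtype=torch.bfloat16))]
print(subtensors_nbytes(example))   # 1024 + 16 = 1040 bytes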
@@ -2551,14 +3327,25 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2551
3327
  dtype = convertWeightsFloatTo if model_dtype == None else model_dtype
2552
3328
  if dtype != torch.float32:
2553
3329
  p.data = p.data.to(dtype)
2554
- if model_dtype== None:
2555
- model_dtype = dtype
3330
+ if model_dtype is None:
3331
+ if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
3332
+ if fp8_fallback_dtype is None:
3333
+ fp8_fallback_dtype = dtype
3334
+ else:
3335
+ model_dtype = dtype
2556
3336
  else:
2557
3337
  if model_dtype != dtype:
2558
- pass
2559
- assert model_dtype == dtype
3338
+ if (
3339
+ dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
3340
+ or model_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
3341
+ ):
3342
+ pass
3343
+ else:
3344
+ assert model_dtype == dtype
2560
3345
  current_model_size += torch.numel(p.data) * p.data.element_size()
2561
- current_model._dtype = model_dtype
3346
+ if model_dtype is None and fp8_fallback_dtype is not None:
3347
+ model_dtype = fp8_fallback_dtype
3348
+ current_model._dtype = model_dtype
2562
3349
  for b in current_model.buffers():
2563
3350
  # do not convert 32 bits float to 16 bits since buffers are few (and potential gain low) and usually they are needed for precision calculation (for instance Rope)
2564
3351
  current_model_size += torch.numel(b.data) * b.data.element_size()
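The dtype resolution above treats float8 weights as a fallback only: a float8 dtype becomes the model's _dtype solely when no higher-precision parameter dtype was encountered. A tiny self-contained illustration of that ordering (not mmgp code):

import torch

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def resolve_model_dtype(param_dtypes):
    model_dtype, fp8_fallback = None, None
    for dtype in param_dtypes:
        if model_dtype is None:
            if dtype in FP8_DTYPES:
                fp8_fallback = fp8_fallback or dtype   # remember, but keep looking
            else:
                model_dtype = dtype
    return model_dtype if model_dtype is not None else fp8_fallback

print(resolve_model_dtype([torch.float8_e4m3fn, torch.bfloat16]))  # torch.bfloat16
print(resolve_model_dtype([torch.float8_e4m3fn]))                  # torch.float8_e4m3fn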
@@ -2584,14 +3371,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2584
3371
  model_budget = new_budget if model_budget == 0 or new_budget < model_budget else model_budget
2585
3372
  if model_budget > 0 and model_budget > current_model_size:
2586
3373
  model_budget = 0
2587
- coef =0.8
3374
+ coef =vram_safety_coefficient
2588
3375
  if current_model_size > coef * self.device_mem_capacity and model_budget == 0 or model_budget > coef * self.device_mem_capacity:
2589
3376
  if verboseLevel >= 1:
2590
3377
  if model_budget == 0:
2591
- print(f"Model '{model_id}' is too large ({current_model_size/ONE_MB:0.1f} MB) to fit entirely in {coef * 100}% of the VRAM (max capacity is {coef * self.device_mem_capacity/ONE_MB}) MB)")
3378
+ print(f"Model '{model_id}' is too large ({current_model_size/ONE_MB:0.1f} MB) to fit entirely in {coef * 100:.0f}% of the VRAM (max capacity is {coef * self.device_mem_capacity/ONE_MB:0.1f}) MB)")
2592
3379
  else:
2593
3380
  print(f"Budget ({budget/ONE_MB:0.1f} MB) for Model '{model_id}' is too important so that this model can fit in the VRAM (max capacity is {self.device_mem_capacity/ONE_MB}) MB)")
2594
- print(f"Budget allocation for this model has been consequently reduced to the 80% of max GPU Memory ({coef * self.device_mem_capacity/ONE_MB:0.1f} MB). This may not leave enough working VRAM and you will probably need to define manually a lower budget for this model.")
3381
+ print(f"Budget allocation for this model has been consequently reduced to the {coef * 100:.0f}% of max GPU Memory ({coef * self.device_mem_capacity/ONE_MB:0.1f} MB). This may not leave enough working VRAM and you will probably need to define manually a lower budget for this model.")
2595
3382
  model_budget = coef * self.device_mem_capacity
2596
3383
 
2597
3384
 
@@ -2607,19 +3394,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2607
3394
  for model_id in models:
2608
3395
  current_model: torch.nn.Module = models[model_id]
2609
3396
  towers_names, towers_modules = _detect_main_towers(current_model)
2610
- # compile main iterative modules stacks ("towers")
2611
3397
  compilationInThisOne = compileAllModels or model_id in modelsToCompile
2612
- if compilationInThisOne:
2613
- if self.verboseLevel>=1:
2614
- if len(towers_modules)>0:
2615
- formated_tower_names = [name + '*' for name in towers_names]
2616
- print(f"Pytorch compilation of '{model_id}' is scheduled for these modules : {formated_tower_names}.")
2617
- else:
2618
- print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
2619
-
2620
- for submodel in towers_modules:
2621
- submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
2622
- #dynamic=True,
2623
3398
 
2624
3399
  if pinAllModels or model_id in modelsToPin:
2625
3400
  if hasattr(current_model,"_already_pinned"):
@@ -2627,6 +3402,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2627
3402
  print(f"Model '{model_id}' already pinned to reserved memory")
2628
3403
  else:
2629
3404
  _pin_to_memory(current_model, model_id, partialPinning= partialPinning, pinnedPEFTLora = pinnedPEFTLora, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel=verboseLevel)
3405
+
2630
3406
  current_budget = model_budgets[model_id]
2631
3407
  cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq, is_mod_seq = None, None, None, -1, False
2632
3408
  self.loaded_blocks[model_id] = None
@@ -2665,8 +3441,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2665
3441
  # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
2666
3442
  top_submodule = len(submodule_name.split("."))==1
2667
3443
  offload_hooks = submodule._offload_hooks if hasattr(submodule, "_offload_hooks") else []
2668
- if len(offload_hooks) > 0:
2669
- pass
2670
3444
  assert top_submodule or len(offload_hooks) == 0, "custom offload hooks can only be set at the top of the module"
2671
3445
  submodule_method_names = ["forward"] + offload_hooks
2672
3446
  for submodule_method_name in submodule_method_names:
@@ -2676,16 +3450,32 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
2676
3450
  else:
2677
3451
  submodule_method = getattr(submodule, submodule_method_name)
2678
3452
  if callable(submodule_method):
2679
- if top_submodule and cur_blocks_name is None:
3453
+ if top_submodule and cur_blocks_name is None and not (any_lora and len(submodule._parameters)):
2680
3454
  self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method, submodule_method_name)
2681
3455
  elif compilationInThisOne and submodule in towers_modules:
2682
3456
  self.hook_preload_blocks_for_compilation(submodule, model_id, cur_blocks_name, context = submodule_name )
2683
3457
  else:
2684
- self.hook_check_empty_cache_needed(submodule, current_model, model_id, cur_blocks_name, submodule_method, context = submodule_name )
2685
-
3458
+ if compilationInThisOne: #and False
3459
+ self.hook_check_load_into_GPU_if_needed(submodule, current_model, model_id, cur_blocks_name, submodule_method, context = submodule_name )
3460
+ else:
3461
+ self.hook_check_load_into_GPU_if_needed_default(submodule, current_model, model_id, cur_blocks_name, submodule_method, context = submodule_name )
3462
+
2686
3463
  self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name, submodule_name)
2687
3464
 
2688
3465
 
3466
+ # compile main iterative modules stacks ("towers")
3467
+ if compilationInThisOne:
3468
+ if self.verboseLevel>=1:
3469
+ if len(towers_modules)>0:
3470
+ formated_tower_names = [name + '*' for name in towers_names]
3471
+ print(f"Pytorch compilation of '{model_id}' is scheduled for these modules : {formated_tower_names}.")
3472
+ else:
3473
+ print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
3474
+
3475
+ for submodel in towers_modules:
3476
+ submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode= compile_mode) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
3477
+ #dynamic=True,
3478
+
2689
3479
  self.tune_preloading(model_id, current_budget, towers_names)
2690
3480
  self.parameters_ref = {}
2691
3481