mmgp 3.3.0__py3-none-any.whl → 3.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mmgp might be problematic.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
- # ------------------ Memory Management 3.3.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+ # ------------------ Memory Management 3.3.2 for the GPU Poor by DeepBeepMeep (mmgp)------------------
  #
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -92,6 +92,8 @@ ONE_MB = 1048576
  sizeofbfloat16 = torch.bfloat16.itemsize
  sizeofint8 = torch.int8.itemsize
  total_pinned_bytes = 0
+ max_pinnable_bytes = 0
+
  physical_memory= psutil.virtual_memory().total

  HEADER = '\033[95m'
@@ -256,11 +258,11 @@ def _move_to_pinned_tensor(source_tensor, big_tensor, offset, length):
  assert t.is_pinned()
  return t

- def _safetensors_load_file(file_path):
+ def _safetensors_load_file(file_path, writable_tensors = True):
  from collections import OrderedDict
  sd = OrderedDict()

- with safetensors2.safe_open(file_path, framework="pt", device="cpu") as f:
+ with safetensors2.safe_open(file_path, framework="pt", device="cpu", writable_tensors =writable_tensors) as f:
  for k in f.keys():
  sd[k] = f.get_tensor(k)
  metadata = f.metadata()
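Note: the new writable_tensors flag added to _safetensors_load_file is forwarded to safetensors2.safe_open (see the safetensors2.py changes further down). A minimal usage sketch, assuming the module is imported from the package and using a hypothetical checkpoint path; leaving the flag at its default True keeps the previous behaviour:

    from mmgp import safetensors2

    # Open a checkpoint with read-only tensors (no copy-on-write pages are created for the mapped file)
    with safetensors2.safe_open("model.safetensors", framework="pt", device="cpu", writable_tensors=False) as f:
        sd = {k: f.get_tensor(k) for k in f.keys()}
        metadata = f.metadata()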
@@ -319,6 +321,13 @@ def _extract_tie_weights_from_sd(sd , sd_name, verboseLevel =1):
  print(f"Found {tied_weights_count} tied weights for a total of {tied_weights_total/ONE_MB:0.2f} MB, last : {tied_weights_last}")

  def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TENSOR_MAX_SIZE, verboseLevel = 1):
+ global max_pinnable_bytes, total_pinned_bytes
+ if max_pinnable_bytes > 0 and max_pinnable_bytes >= max_pinnable_bytes:
+
+ if verboseLevel>=1 :
+ print(f"Unable pin data of '{sd_name}' to reserved RAM as there is no reserved RAM left")
+ return
+
  current_big_tensor_size = 0
  big_tensor_no = 0
  big_tensors_sizes = []
@@ -393,10 +402,19 @@ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TE


  def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = True, gig_tensor_size = BIG_TENSOR_MAX_SIZE, verboseLevel = 1):
+
+ global max_pinnable_bytes, total_pinned_bytes
+ if max_pinnable_bytes > 0 and max_pinnable_bytes >= max_pinnable_bytes:
+
+ if verboseLevel>=1 :
+ print(f"Unable pin data of '{model_id}' to reserved RAM as there is no reserved RAM left")
+ return
+
  if partialPinning:
  towers_names, _ = _detect_main_towers(model)


+
  current_big_tensor_size = 0
  big_tensor_no = 0
  big_tensors_sizes = []
@@ -484,22 +502,27 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  total = 0


+ failed_planned_allocation = False

- for size in big_tensors_sizes:
- try:
- current_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True, device="cpu")
- big_tensors.append(current_big_tensor)
- except:
- print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
- break
+ # for size in big_tensors_sizes:
+ # try:
+ # # if total > 7000 * ONE_MB:
+ # # raise Exception ("test no more reserved RAM")
+ # current_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True, device="cpu")
+ # big_tensors.append(current_big_tensor)
+ # except:
+ # print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
+ # max_pinnable_bytes = total + total_pinned_bytes
+ # failed_planned_allocation = True
+ # break

- last_big_tensor += 1
- total += size
+ # last_big_tensor += 1
+ # total += size


  gc.collect()

-
+ last_allocated_big_tensor = -1
  tensor_no = 0
  # prev_big_tensor = 0
  for n, (p, is_buffer) in params_dict.items():
@@ -520,46 +543,63 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  assert p.data.is_pinned()
  q = None
  else:
+
  big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
+ if last_allocated_big_tensor < big_tensor_no:
+ last_allocated_big_tensor += 1
+ size = big_tensors_sizes[last_allocated_big_tensor]
+ try:
+ # if total > 7000 * ONE_MB:
+ # raise Exception ("test no more reserved RAM")
+ current_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True, device="cpu")
+ big_tensors.append(current_big_tensor)
+ except:
+ print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
+ max_pinnable_bytes = total + total_pinned_bytes
+ failed_planned_allocation = True
+ break
+
+ total += size
+
  # if big_tensor_no != prev_big_tensor:
  # gc.collect()
  # prev_big_tensor = big_tensor_no
  # match_param, match_isbuffer = tied_weights.get(n, (None, False))
  # if match_param != None:

- if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
- current_big_tensor = big_tensors[big_tensor_no]
- if is_buffer :
- _force_load_buffer(p) # otherwise potential memory leak
- if isinstance(p, QTensor):
- if p._qtype == qint4:
- length1 = torch.numel(p._data._data) * p._data._data.element_size()
- p._data._data = _move_to_pinned_tensor(p._data._data, current_big_tensor, offset, length1)
- if hasattr(p,"_scale_shift"):
- length2 = torch.numel(p._scale_shift) * p._scale_shift.element_size()
- p._scale_shift = _move_to_pinned_tensor(p._scale_shift, current_big_tensor, offset + length1, length2)
- else:
- length2 = torch.numel(p._scale) * p._scale.element_size()
- p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
- length3 = torch.numel(p._shift) * p._shift.element_size()
- p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
+ # if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
+ current_big_tensor = big_tensors[big_tensor_no]
+ if is_buffer :
+ _force_load_buffer(p) # otherwise potential memory leak
+ if isinstance(p, QTensor):
+ if p._qtype == qint4:
+ length1 = torch.numel(p._data._data) * p._data._data.element_size()
+ p._data._data = _move_to_pinned_tensor(p._data._data, current_big_tensor, offset, length1)
+ if hasattr(p,"_scale_shift"):
+ length2 = torch.numel(p._scale_shift) * p._scale_shift.element_size()
+ p._scale_shift = _move_to_pinned_tensor(p._scale_shift, current_big_tensor, offset + length1, length2)
  else:
- length1 = torch.numel(p._data) * p._data.element_size()
- p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
  length2 = torch.numel(p._scale) * p._scale.element_size()
  p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
+ length3 = torch.numel(p._shift) * p._shift.element_size()
+ p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
  else:
- length = torch.numel(p.data) * p.data.element_size()
- p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
+ length1 = torch.numel(p._data) * p._data.element_size()
+ p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
+ length2 = torch.numel(p._scale) * p._scale.element_size()
+ p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
+ else:
+ length = torch.numel(p.data) * p.data.element_size()
+ p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
  tensor_no += 1
  del p
- global total_pinned_bytes
+ model._pinned_bytes = total
  total_pinned_bytes += total
  del params_dict
  gc.collect()

  if verboseLevel >=1:
- if partialPinning:
+ if partialPinning or failed_planned_allocation:
  print(f"The model was partially pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
  else:
  print(f"The whole model was pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
@@ -575,7 +615,7 @@ def _welcome():
  if welcome_displayed:
  return
  welcome_displayed = True
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.2.8) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.3.3) by DeepBeepMeep ************{ENDC}{UNBOLD}")

  def _extract_num_from_str(num_in_str):
  size = len(num_in_str)
@@ -1128,7 +1168,7 @@ def move_loras_to_device(model, device="cpu" ):
  if ".lora_" in k:
  m.to(device)

- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, modelClass=None, modelPrefix = None, verboseLevel = -1):
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1):
  """
  quick version of .LoadfromPretrained of the transformers library
  used to build a model and load the corresponding weights (quantized or not)
@@ -1144,7 +1184,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat

  verboseLevel = _compute_verbose_level(verboseLevel)

- with safetensors2.safe_open(model_path) as f:
+ with safetensors2.safe_open(model_path, writable_tensors =writable_tensors) as f:
  metadata = f.metadata()

  if metadata is None:
@@ -1208,13 +1248,13 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat

  model._config = transformer_config

- load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, verboseLevel=verboseLevel )
+ load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors ,verboseLevel=verboseLevel )

  return model



- def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, verboseLevel = -1):
+ def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, verboseLevel = -1):
  """
  Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
  """
@@ -1252,7 +1292,7 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  if "module" in state_dict:
  state_dict = state_dict["module"]
  else:
- state_dict, metadata = _safetensors_load_file(file_path)
+ state_dict, metadata = _safetensors_load_file(file_path, writable_tensors =writable_tensors)

  if metadata is None:
  quantization_map = None
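Note: both fast_load_transformers_model and load_model_data now accept writable_tensors and pass it down to _safetensors_load_file. A hedged usage sketch; the checkpoint path is a placeholder and the other arguments simply reuse parameters visible in the signatures above:

    from mmgp import offload

    text_encoder = offload.fast_load_transformers_model(
        "ckpts/text_encoder.safetensors",   # hypothetical path
        do_quantize=True,
        pinToMemory=True,
        partialPinning=True,
        writable_tensors=False,             # new in 3.3.2: keep the memory-mapped weights read-only
    )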
@@ -1447,7 +1487,6 @@ class offload:
  def __init__(self):
  self.active_models = []
  self.active_models_ids = []
- self.active_subcaches = {}
  self.models = {}
  self.cotenants_map = {
  "text_encoder": ["vae", "text_encoder_2"],
@@ -1709,7 +1748,6 @@ class offload:

  self.active_models = []
  self.active_models_ids = []
- self.active_subcaches = []
  torch.cuda.empty_cache()
  gc.collect()
  self.last_reserved_mem_check = time.time()
@@ -2022,24 +2060,29 @@ class offload:
  print(f"Async loading plan for model '{model_id}' : {(preload_total+base_size)/ONE_MB:0.2f} MB will be preloaded (base size of {base_size/ONE_MB:0.2f} MB + {preload_total/total_size*100:0.1f}% of recurrent layers data) with a {max_blocks_fetch/ONE_MB:0.2f} MB async" + (" circular" if len(towers) == 1 else "") + " shuttle")

  def release(self):
- global last_offload_obj
+ global last_offload_obj, total_pinned_bytes

  if last_offload_obj == self:
  last_offload_obj = None

  self.unload_all()
- self.default_stream = None
+ self.active_models = None
+ self.default_stream = None
+ self.transfer_stream = None
+ self.parameters_ref = None
  keys= [k for k in self.blocks_of_modules.keys()]
  for k in keys:
  del self.blocks_of_modules[k]

  self.blocks_of_modules = None

-
  for model_id, model in self.models.items():
  move_loras_to_device(model, "cpu")
+ if hasattr(model, "_pinned_bytes"):
+ total_pinned_bytes -= model._pinned_bytes
  if hasattr(model, "_loras_model_data"):
  unload_loras_from_model(model)
+ model = None

  self.models = None

@@ -2049,7 +2092,7 @@ class offload:



- def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
+ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, partialPinning = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, convertFloatToBfloat16 = True, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
  """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
  pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
  quantizeTransformer: set True by default will quantize on the fly the video / image model
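Note: offload.all gains two keyword arguments in this release: partialPinning, which was previously hard-coded to False inside the function (see the hunk below), and convertFloatToBfloat16, which lets callers opt out of the automatic float32 to bfloat16 conversion shown further down. A hedged sketch of hooking a pipeline with the new options; pipe stands for an already-built pipeline object:

    from mmgp import offload

    offloadobj = offload.all(
        pipe,
        pinnedMemory=True,
        partialPinning=True,            # new: request partial pinning explicitly
        convertFloatToBfloat16=False,   # new: leave float32 weights untouched
        asyncTransfers=True,
    )
    # ... run inference ...
    offloadobj.release()                # release() now also subtracts the model's pinned bytes from the global counter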
@@ -2156,7 +2199,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, l
  modelPinned = (pinAllModels or model_id in modelsToPin) and not hasattr(current_model,"_already_pinned")

  current_model_size = 0
-
  for n, p in current_model.named_parameters():
  p.requires_grad = False
  if isinstance(p, QTensor):
@@ -2176,7 +2218,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, l
  current_model_size += torch.numel(p._data) * p._data.element_size()

  else:
- if p.data.dtype == torch.float32:
+ if convertFloatToBfloat16 and p.data.dtype == torch.float32:
  # convert any left overs float32 weight to bloat16 to divide by 2 the model memory footprint
  p.data = p.data.to(torch.bfloat16)
  current_model_size += torch.numel(p.data) * p.data.element_size()
@@ -2219,9 +2261,8 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, l


  model_budgets[model_id] = model_budget

- partialPinning = False

- if estimatesBytesToPin > 0 and estimatesBytesToPin >= (max_reservable_memory - total_pinned_bytes):
+ if not partialPinning and estimatesBytesToPin > 0 and estimatesBytesToPin >= (max_reservable_memory - total_pinned_bytes):
  if self.verboseLevel >=1:
  print(f"Switching to partial pinning since full requirements for pinned models is {estimatesBytesToPin/ONE_MB:0.1f} MB while estimated available reservable RAM is {(max_reservable_memory-total_pinned_bytes)/ONE_MB:0.1f} MB. You may increase the value of parameter 'perc_reserved_mem_max' to a value higher than {perc_reserved_mem_max:0.2f} to force full pinnning." )
  partialPinning = True
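As a worked example with made-up numbers: if the models selected for pinning are estimated at 20,000 MB, max_reservable_memory is 24,000 MB and 6,000 MB are already pinned, then 20,000 >= 24,000 - 6,000 = 18,000, so the hook prints the warning and switches to partial pinning; raising perc_reserved_mem_max increases max_reservable_memory and can restore full pinning. With the added `not partialPinning` condition, the message is skipped when the caller already requested partial pinning through the new offload.all argument.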
mmgp/safetensors2.py CHANGED
@@ -1,4 +1,4 @@
- # ------------------ Safetensors2 1.0 by DeepBeepMeep (mmgp)------------------
+ # ------------------ Safetensors2 1.1 by DeepBeepMeep (mmgp)------------------
  #
  # This module entirely written in Python is a replacement for the safetensor library which requires much less RAM to load models.
  # It can be conveniently used to keep a low RAM consumption when handling transit data (for instance when quantizing or transferring tensors to reserver RAM)
@@ -16,12 +16,14 @@ import safetensors
  import accelerate
  import os
  from collections import OrderedDict
+ import warnings

+ warnings.filterwarnings("ignore", ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*")

  _old_torch_load_file = None
  _old_safe_open = None

-
+ all_tensors_are_read_only = False

  mmm = {}
  verboseLevel = 1
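Note: the suppressed warning is the one PyTorch emits when a tensor is built on top of a non-writable buffer, which is exactly what happens once the file is mapped with mmap.ACCESS_READ (see the create_tensors_with_mmap change below). A minimal illustration of the warning and the filter, independent of mmgp:

    import warnings
    import torch

    warnings.filterwarnings(
        "ignore",
        ".*The given buffer is not writable, and PyTorch does not support non-writable tensors*",
    )
    buf = bytes(16)                                 # a read-only buffer
    t = torch.frombuffer(buf, dtype=torch.uint8)    # would emit the UserWarning without the filter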
@@ -232,7 +234,7 @@ def torch_write_file(sd, file_path, quantization_map = None, config = None, extr
  class SafeTensorFile:
  """Main class for accessing safetensors files that provides memory-efficient access"""

- def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True):
+ def __init__(self, file_path, metadata, catalog, skip_bytes, lazy_loading = True, writable_tensors = True):
  self._file_path = file_path
  self._metadata = metadata
  self._catalog = catalog
@@ -241,19 +243,20 @@ class SafeTensorFile:
  self.sd = None
  self.mtracker = None
  self.lazy_loading = lazy_loading
+ self.writable_tensors = writable_tensors

  @classmethod
- def load_metadata(cls, file_path, lazy_loading = True):
+ def load_metadata(cls, file_path, lazy_loading = True, writable_tensors = True):
  with open(file_path, 'rb') as f:
  catalog, metadata, skip_bytes = _read_safetensors_header(file_path, f)

- return cls(file_path, metadata, catalog, skip_bytes, lazy_loading)
+ return cls(file_path, metadata, catalog, skip_bytes, lazy_loading, writable_tensors )

- def init_tensors(self, lazyTensors = True):
+ def init_tensors(self, lazyTensors = True, writable_tensors = True):
  if self.sd is None:
  self.lazy_loading = lazyTensors
  if lazyTensors:
- self.sd = self.create_tensors_with_mmap()
+ self.sd = self.create_tensors_with_mmap(writable_tensors)
  else:
  self.sd = self.create_tensors_without_mmap()
  # else:
@@ -263,7 +266,7 @@ class SafeTensorFile:
  return self.sd


- def create_tensors_with_mmap(self):
+ def create_tensors_with_mmap(self, writable_tensors = True):

  self.mtracker = MmapTracker(self._file_path)
  import mmap
@@ -302,7 +305,7 @@ class SafeTensorFile:
  with open(self._file_path, 'rb') as f:
  i = 0
  for map_start, map_size in maps_info:
- mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access=mmap.ACCESS_COPY) #.ACCESS_READ
+ mm = mmap.mmap(f.fileno(), map_size, offset=map_start, access= mmap.ACCESS_COPY if writable_tensors else mmap.ACCESS_READ)
  maps.append((mm, map_start, map_size))
  self.mtracker.register(mm, i, map_start, map_size)
  i = i+ 1
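Note: this line is where writable_tensors actually takes effect: mmap.ACCESS_COPY creates a private copy-on-write mapping (tensors can be modified in place, at the cost of extra RAM for every touched page), while mmap.ACCESS_READ creates a plain read-only mapping. A small standalone sketch of the two modes, with a hypothetical file name:

    import mmap

    with open("model.safetensors", "rb") as f:
        cow = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_COPY)  # writes stay private to the process
        ro = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)   # any write attempt raises TypeError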
@@ -359,7 +362,7 @@ class SafeTensorFile:
  def get_tensor(self, name: str) -> torch.tensor:
  """Get a tensor by name"""
  # To do : switch to a JIT tensor creation per tensor
- self.init_tensors()
+ self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
  return self.sd[name]

  def keys(self) -> List[str]:
@@ -374,7 +377,7 @@ class SafeTensorFile:

  def tensors(self) -> Dict[str, torch.tensor]:
  """Get dictionary of all tensors"""
- self.init_tensors(self.lazy_loading)
+ self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
  return self.sd

  def metadata(self) -> Optional[Dict[str, str]]:
@@ -383,7 +386,7 @@ class SafeTensorFile:

  def __len__(self) -> int:
  """Get number of tensors"""
- self.init_tensors(self.lazy_loading)
+ self.init_tensors(self.lazy_loading, writable_tensors= self.writable_tensors)
  return len(self.keys())

  def __contains__(self, key: str) -> bool:
@@ -401,17 +404,22 @@ class SafeTensorFile:
  class _SafeTensorLoader:
  """Context manager for loading SafeTensorFile"""

- def __init__(self, filename: str ):
+ def __init__(self, filename: str, writable_tensors = True ):
  self.filename = Path(filename)
+ self.writable_tensors = writable_tensors
  self.sft = None
  if not self.filename.exists():
  raise FileNotFoundError(f"File not found: {filename}")

  def __enter__(self) -> SafeTensorFile:
  """Open file and return SafeTensorFile instance"""
-
+ writable_tensors = self.writable_tensors
+
+ if all_tensors_are_read_only:
+ writable_tensors = False
+
  try:
- self.sft = SafeTensorFile.load_metadata(self.filename)
+ self.sft = SafeTensorFile.load_metadata(self.filename, writable_tensors= writable_tensors)
  return self.sft

  except Exception as e:
@@ -428,14 +436,14 @@ class _SafeTensorLoader:
  pass


- def safe_open(filename: str, framework: str = "pt",device = "cpu") -> _SafeTensorLoader:
+ def safe_open(filename: str, framework: str = "pt",device = "cpu", writable_tensors = True) -> _SafeTensorLoader:
  if device != "cpu" or framework !="pt":
  return _old_safe_open(filename =filename, framework=framework, device=device)
- return _SafeTensorLoader(filename)
+ return _SafeTensorLoader(filename, writable_tensors = writable_tensors)

- def torch_load_file( filename, device = 'cpu' ) -> Dict[str, torch.Tensor]:
+ def torch_load_file( filename, device = 'cpu', writable_tensors = True) -> Dict[str, torch.Tensor]:
  sd = {}
- with safe_open(filename, framework="pt", device = device ) as f:
+ with safe_open(filename, framework="pt", device = device, writable_tensors =writable_tensors ) as f:
  for k in f.keys():
  sd[k] = f.get_tensor(k)
  return sd
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mmgp
- Version: 3.3.0
+ Version: 3.3.2
  Summary: Memory Management for the GPU Poor
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
  License: GNU GENERAL PUBLIC LICENSE
@@ -17,7 +17,7 @@ Dynamic: license-file


  <p align="center">
- <H2>Memory Management 3.3.0 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.3.2 for the GPU Poor by DeepBeepMeep</H2>
  </p>


@@ -0,0 +1,9 @@
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
+ mmgp/offload.py,sha256=43FnFfWqwhh2qz0uykqEpxb_XP9Jx8MPGzN31PExT2w,107470
+ mmgp/safetensors2.py,sha256=rmUbBmK3Dra5prUTTRSVi6-XUFAa9Mj6B5CNPgzt9To,17333
+ mmgp-3.3.2.dist-info/licenses/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
+ mmgp-3.3.2.dist-info/METADATA,sha256=mVMLkutqhUihIeo8uo_LK71ithm84_AEaNvnyRnzmEA,16153
+ mmgp-3.3.2.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
+ mmgp-3.3.2.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
+ mmgp-3.3.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (77.0.3)
+ Generator: setuptools (78.0.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -1,9 +0,0 @@
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
- mmgp/offload.py,sha256=xdlYbB8nKUywAAMPcfCzJmCxYHvBB5vcZgv2wEQTtbE,105329
- mmgp/safetensors2.py,sha256=DCdlRH3769CTyraAmWAB3b0XrVua7z6ygQ-OyKgJN6A,16453
- mmgp-3.3.0.dist-info/licenses/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
- mmgp-3.3.0.dist-info/METADATA,sha256=33eB_YmC6PciTkzi_Z_gsWWzoz6RJgyLbEItFatVghk,16153
- mmgp-3.3.0.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
- mmgp-3.3.0.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
- mmgp-3.3.0.dist-info/RECORD,,