unifiedefficientloader 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unifiedefficientloader/__init__.py +36 -0
- unifiedefficientloader/memory_efficient_loader.py +361 -0
- unifiedefficientloader/pinned_transfer.py +89 -0
- unifiedefficientloader/tensor_utils.py +54 -0
- unifiedefficientloader-0.2.0.dist-info/METADATA +132 -0
- unifiedefficientloader-0.2.0.dist-info/RECORD +9 -0
- unifiedefficientloader-0.2.0.dist-info/WHEEL +5 -0
- unifiedefficientloader-0.2.0.dist-info/licenses/LICENSE +21 -0
- unifiedefficientloader-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
|
|
3
|
+
def check_dependencies(*packages):
    """
    Verify that every named package is importable.

    Args:
        *packages: Top-level package names to check.

    Raises:
        ImportError: listing every package that could not be found,
            with a ready-to-run pip command.
    """
    missing = [pkg for pkg in packages if importlib.util.find_spec(pkg) is None]

    if missing:
        missing_str = ", ".join(missing)
        raise ImportError(
            f"Missing required packages for unifiedefficientloader: {missing_str}. "
            f"Please install them using: pip install {missing_str}"
        )
|
|
19
|
+
|
|
20
|
+
# Pre-check torch as it is the foundation of most of these tools
|
|
21
|
+
check_dependencies("torch")
|
|
22
|
+
|
|
23
|
+
from .memory_efficient_loader import UnifiedSafetensorsLoader, MemoryEfficientSafeOpen
|
|
24
|
+
from .tensor_utils import dict_to_tensor, tensor_to_dict
|
|
25
|
+
from .pinned_transfer import transfer_to_gpu_pinned, set_verbose, get_pinned_transfer_stats, reset_pinned_transfer_stats
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"UnifiedSafetensorsLoader",
|
|
29
|
+
"MemoryEfficientSafeOpen",
|
|
30
|
+
"dict_to_tensor",
|
|
31
|
+
"tensor_to_dict",
|
|
32
|
+
"transfer_to_gpu_pinned",
|
|
33
|
+
"set_verbose",
|
|
34
|
+
"get_pinned_transfer_stats",
|
|
35
|
+
"reset_pinned_transfer_stats",
|
|
36
|
+
]
|
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified safetensors loader with optional memory-efficient mode.
|
|
3
|
+
|
|
4
|
+
Provides a consistent interface for tensor loading regardless of mode.
|
|
5
|
+
Requires `torch`, `safetensors`, and optionally `tqdm`.
|
|
6
|
+
"""
|
|
7
|
+
import gc
|
|
8
|
+
import json
|
|
9
|
+
import struct
|
|
10
|
+
import logging
|
|
11
|
+
from typing import Dict, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
def _ensure_torch():
|
|
16
|
+
try:
|
|
17
|
+
import torch
|
|
18
|
+
return torch
|
|
19
|
+
except ImportError:
|
|
20
|
+
raise ImportError("The 'torch' package is required but not installed. Please install it.")
|
|
21
|
+
|
|
22
|
+
def _ensure_safetensors():
    """Import safetensors and return its safe_open entry point.

    Raises:
        ImportError: with install guidance when safetensors is missing.
    """
    try:
        import safetensors
        from safetensors import safe_open
    except ImportError:
        raise ImportError("The 'safetensors' package is required but not installed. Please install it.")
    return safe_open
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
import torch
|
|
32
|
+
except ImportError:
|
|
33
|
+
pass
|
|
34
|
+
|
|
35
|
+
class UnifiedSafetensorsLoader:
    """Unified safetensors loader supporting both preload and streaming modes.

    In standard mode (low_memory=False):
    - Loads all tensors upfront (fast, uses more RAM)
    - Tensors remain in memory until explicitly deleted

    In low-memory mode (low_memory=True):
    - Loads tensors on-demand via get_tensor()
    - Caller should delete tensors after processing

    Usage:
        with UnifiedSafetensorsLoader("model.safetensors", low_memory=True) as loader:
            for key in loader.keys():
                tensor = loader.get_tensor(key)
                # ... process tensor ...
                loader.mark_processed(key)  # Frees the cache in standard mode;
                                            # a no-op in low-memory mode (nothing is cached)
    """
|
|
53
|
+
|
|
54
|
+
def __init__(self, filename: str, low_memory: bool = False):
    """Initialize the loader.

    Args:
        filename: Path to safetensors file
        low_memory: If True, use streaming mode; if False, preload all tensors
    """
    # Availability checks; torch's return value itself is unused here.
    torch = _ensure_torch()
    safe_open = _ensure_safetensors()

    self.filename = filename
    self.low_memory = low_memory
    self._tensors: Dict[str, 'torch.Tensor'] = {}
    self._all_keys = []
    self._file = None
    self._header = None
    self._header_size = None
    self._metadata: Dict[str, str] = {}

    if low_memory:
        # Streaming mode: read header only
        self._header, self._header_size = self._read_header()
        self._file = None  # Opened lazily to support multiprocessing DataLoader
        self._all_keys = [k for k in self._header.keys() if k != "__metadata__"]
        # Extract metadata from header (safetensors stores it under __metadata__ key)
        self._metadata = self._header.get("__metadata__", {})
        logger.debug(f"Initialized Low-memory mode: parsed header of size {self._header_size} bytes.")
        logger.debug(f"Found {len(self._all_keys)} tensors (streaming mode)")
    else:
        # Standard mode: preload all tensors
        with safe_open(filename, framework="pt", device="cpu") as f:
            self._metadata = f.metadata() or {}
            self._all_keys = list(f.keys())
            print(f"Loading {len(self._all_keys)} tensors from source file...")
            # tqdm is an optional extra; fall back to a bare iterator silently.
            try:
                from tqdm import tqdm
                iterator = tqdm(self._all_keys, desc="Loading tensors")
            except ImportError:
                iterator = self._all_keys

            for key in iterator:
                self._tensors[key] = f.get_tensor(key)
|
|
96
|
+
|
|
97
|
+
def __enter__(self):
    """Context-manager entry; returns the loader itself."""
    return self
|
|
99
|
+
|
|
100
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: close the file handle and drop cached tensors."""
    self.close()
|
|
102
|
+
|
|
103
|
+
def __getstate__(self):
|
|
104
|
+
"""Make loader picklable for multiprocessing DataLoaders."""
|
|
105
|
+
state = self.__dict__.copy()
|
|
106
|
+
state['_file'] = None
|
|
107
|
+
return state
|
|
108
|
+
|
|
109
|
+
def __setstate__(self, state):
    """Restore pickled state; the file handle is reopened lazily on demand."""
    self.__dict__.update(state)
|
|
111
|
+
|
|
112
|
+
def close(self):
|
|
113
|
+
"""Close file handle and release resources."""
|
|
114
|
+
if self._file:
|
|
115
|
+
self._file.close()
|
|
116
|
+
self._file = None
|
|
117
|
+
self._tensors.clear()
|
|
118
|
+
|
|
119
|
+
def keys(self):
    """Return list of all tensor keys.

    Note: this is the loader's internal list, not a copy — callers should
    not mutate it.
    """
    return self._all_keys
|
|
122
|
+
|
|
123
|
+
def metadata(self) -> Dict[str, str]:
    """Return file metadata (the safetensors `__metadata__` mapping; may be empty)."""
    return self._metadata
|
|
126
|
+
|
|
127
|
+
def get_shape(self, key: str) -> tuple:
|
|
128
|
+
"""Get tensor shape without loading tensor data.
|
|
129
|
+
|
|
130
|
+
In low-memory mode, reads from header.
|
|
131
|
+
In standard mode, returns shape from loaded tensor.
|
|
132
|
+
"""
|
|
133
|
+
if self.low_memory:
|
|
134
|
+
if key not in self._header:
|
|
135
|
+
raise KeyError(f"Tensor '{key}' not found in file")
|
|
136
|
+
return tuple(self._header[key]["shape"])
|
|
137
|
+
else:
|
|
138
|
+
return tuple(self._tensors[key].shape)
|
|
139
|
+
|
|
140
|
+
def get_ndim(self, key: str) -> int:
|
|
141
|
+
"""Get tensor ndim without loading tensor data."""
|
|
142
|
+
return len(self.get_shape(key))
|
|
143
|
+
|
|
144
|
+
def get_tensor(self, key: str) -> 'torch.Tensor':
    """Get a tensor by key.

    In standard mode, returns from the preloaded cache.
    In low-memory mode, loads the tensor's byte range from file on-demand.

    Raises:
        KeyError: if the key is not present.
        IOError: if the file ends before the tensor's byte range completes
            (truncated/corrupted file) — low-memory mode only.
    """
    if not self.low_memory:
        # Standard mode: return from preloaded cache
        return self._tensors[key]

    # Low-memory mode: load on-demand
    if key not in self._header:
        raise KeyError(f"Tensor '{key}' not found in file")

    if self._file is None:
        # Opened lazily so a pickled loader works in DataLoader workers.
        self._file = open(self.filename, "rb")

    metadata = self._header[key]
    offset_start, offset_end = metadata["data_offsets"]

    if offset_start != offset_end:
        logger.debug(f"Loading tensor '{key}' from offset {offset_start} to {offset_end} ({(offset_end - offset_start)} bytes)")
        # Data section begins after the 8-byte length prefix plus the header.
        self._file.seek(self._header_size + 8 + offset_start)
        # Use bytearray to create a writable buffer, avoiding PyTorch warning
        # about non-writable tensors from read-only bytes.
        tensor_bytes = bytearray(offset_end - offset_start)
        n_read = self._file.readinto(tensor_bytes)
        if n_read != len(tensor_bytes):
            # Short read: the file is truncated. Fail loudly rather than
            # silently deserializing a buffer padded with zeros.
            raise IOError(
                f"Short read for tensor '{key}': expected {len(tensor_bytes)} bytes, got {n_read}"
            )
    else:
        # Zero-sized tensor: no bytes to read.
        tensor_bytes = None

    return self._deserialize_tensor(tensor_bytes, metadata)
|
|
175
|
+
|
|
176
|
+
def mark_processed(self, key: str):
|
|
177
|
+
"""Mark a tensor as processed, freeing memory if in low-memory mode.
|
|
178
|
+
|
|
179
|
+
In standard mode, optionally deletes from cache.
|
|
180
|
+
In low-memory mode, this is a no-op (tensor was never cached).
|
|
181
|
+
"""
|
|
182
|
+
if not self.low_memory and key in self._tensors:
|
|
183
|
+
del self._tensors[key]
|
|
184
|
+
gc.collect()
|
|
185
|
+
|
|
186
|
+
def _read_header(self):
|
|
187
|
+
"""Read and parse the safetensors header."""
|
|
188
|
+
with open(self.filename, "rb") as f:
|
|
189
|
+
header_size = struct.unpack("<Q", f.read(8))[0]
|
|
190
|
+
header_json = f.read(header_size).decode("utf-8")
|
|
191
|
+
return json.loads(header_json), header_size
|
|
192
|
+
|
|
193
|
+
def _deserialize_tensor(self, tensor_bytes, metadata):
    """Turn a raw byte buffer into a torch tensor per the header metadata.

    `tensor_bytes` is None for zero-sized tensors.
    """
    torch = _ensure_torch()
    dtype_str = metadata["dtype"]
    shape = metadata["shape"]
    dtype = self._get_torch_dtype(dtype_str)

    if tensor_bytes is None:
        buffer = torch.empty(0, dtype=torch.uint8)
    else:
        buffer = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

    # float8 takes a dedicated view path (availability depends on torch build).
    if dtype_str in ("F8_E5M2", "F8_E4M3"):
        return self._convert_float8(buffer, dtype_str, shape)

    return buffer.view(dtype).reshape(shape)
|
|
209
|
+
|
|
210
|
+
@staticmethod
def _get_torch_dtype(dtype_str: str):
    """Map a safetensors dtype string to a torch dtype.

    Raises:
        ValueError: for dtype strings with no torch equivalent in this build.
    """
    torch = _ensure_torch()
    mapping = {
        "F64": torch.float64,
        "F32": torch.float32,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "I8": torch.int8,
        "U8": torch.uint8,
        "BOOL": torch.bool,
    }
    # float8 dtypes only exist in sufficiently recent torch builds.
    for name, attr in (("F8_E5M2", "float8_e5m2"), ("F8_E4M3", "float8_e4m3fn")):
        if hasattr(torch, attr):
            mapping[name] = getattr(torch, attr)

    dtype = mapping.get(dtype_str)
    if dtype is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return dtype
|
|
235
|
+
|
|
236
|
+
@staticmethod
def _convert_float8(byte_tensor, dtype_str: str, shape: list):
    """View raw uint8 data as a float8 tensor, if this torch build supports it.

    Raises:
        ValueError: when the requested float8 variant is unavailable.
    """
    torch = _ensure_torch()
    target = None
    if dtype_str == "F8_E5M2":
        target = getattr(torch, "float8_e5m2", None)
    elif dtype_str == "F8_E4M3":
        target = getattr(torch, "float8_e4m3fn", None)

    if target is None:
        raise ValueError(f"Unsupported float8 type: {dtype_str}")
    return byte_tensor.view(target).reshape(shape)
|
|
246
|
+
|
|
247
|
+
def async_stream(self, keys: list, batch_size: int = 1, prefetch_batches: int = 2, pin_memory: bool = False):
    """Asynchronously stream tensors from disk.

    A background producer thread fans reads out to a ThreadPoolExecutor and
    feeds a bounded queue, so memory use stays capped at roughly
    prefetch_batches * batch_size tensors.

    Args:
        keys: List of tensor keys to load
        batch_size: Number of tensors to yield in each batch
        prefetch_batches: Number of batches to pre-fetch in background
        pin_memory: If True, tensors will be pinned in CPU memory (sequentially in main thread)

    Yields:
        List of (key, tensor) tuples
    """
    import threading
    import queue
    from concurrent.futures import ThreadPoolExecutor

    # Availability check; the torch module itself is not used below.
    torch = _ensure_torch()
    thread_local = threading.local()

    def get_file_handle():
        # One handle per worker thread so seeks don't race each other.
        # NOTE(review): these handles are only closed when the pool threads
        # are reclaimed — confirm that is acceptable for long-lived loaders.
        if not hasattr(thread_local, 'file'):
            thread_local.file = open(self.filename, "rb")
        return thread_local.file

    def _worker_load(key):
        # Runs in a pool thread; never raises — errors are returned as the
        # third tuple element so the main thread can fall back synchronously.
        try:
            # Direct thread-safe read
            metadata = self._header[key]
            offset_start, offset_end = metadata["data_offsets"]
            if offset_start != offset_end:
                f = get_file_handle()
                f.seek(self._header_size + 8 + offset_start)
                tensor_bytes = bytearray(offset_end - offset_start)
                f.readinto(tensor_bytes)
            else:
                tensor_bytes = None

            tensor = self._deserialize_tensor(tensor_bytes, metadata)
            return key, tensor, None
        except Exception as e:
            # Fallback info for main thread
            return key, None, e

    # Queue for individual (key, tensor) pairs
    # Size it to hold enough for prefetch_batches
    q = queue.Queue(maxsize=prefetch_batches * batch_size)

    def _producer():
        # Use a reasonable number of workers for I/O bound tasks
        max_workers = min(16, max(4, batch_size))
        # Limit task submission to maintain backpressure on memory
        max_in_flight = max(max_workers, prefetch_batches * batch_size)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            key_iter = iter(keys)

            # Fill the pipeline
            for _ in range(max_in_flight):
                try:
                    k = next(key_iter)
                    futures.append(executor.submit(_worker_load, k))
                except StopIteration:
                    break

            while futures:
                # Maintain order by taking the first future
                f = futures.pop(0)
                result = f.result()  # Blocks until this specific tensor is loaded
                q.put(result)  # Blocks if the consumption queue is full

                # Submit next task if available
                try:
                    k = next(key_iter)
                    futures.append(executor.submit(_worker_load, k))
                except StopIteration:
                    pass

        q.put(None)  # Sentinel

    producer_thread = threading.Thread(target=_producer, daemon=True)
    producer_thread.start()

    batch = []
    while True:
        res = q.get()
        if res is None:
            # Producer finished: flush the final partial batch, if any.
            if batch:
                yield batch
            break

        k, t, err = res
        if err is not None:
            logger.warning(f"Async load failed for {k}, falling back to sync: {err}")
            # Fallback synchronous load
            try:
                t = self.get_tensor(k)
            except Exception as sync_err:
                logger.error(f"Sync fallback also failed for {k}: {sync_err}")
                raise sync_err

        # Pin memory sequentially in the main thread to avoid OS-level lock contention
        if pin_memory and t.device.type == 'cpu':
            try:
                t = t.pin_memory()
            except Exception as e:
                logger.warning(f"Failed to pin memory for {k}: {e}")

        batch.append((k, t))
        if len(batch) == batch_size:
            yield batch
            batch = []
|
|
359
|
+
|
|
360
|
+
# Backward compatibility alias
|
|
361
|
+
MemoryEfficientSafeOpen = UnifiedSafetensorsLoader
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pinned memory utilities for faster CPU→GPU tensor transfers.
|
|
3
|
+
|
|
4
|
+
Pinned (page-locked) memory enables faster DMA transfers to GPU.
|
|
5
|
+
Uses PyTorch's native pin_memory() with non_blocking transfers.
|
|
6
|
+
"""
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
def _ensure_torch():
|
|
13
|
+
try:
|
|
14
|
+
import torch
|
|
15
|
+
return torch
|
|
16
|
+
except ImportError:
|
|
17
|
+
raise ImportError("The 'torch' package is required but not installed. Please install it.")
|
|
18
|
+
|
|
19
|
+
# Module-level configuration
|
|
20
|
+
_verbose = False
|
|
21
|
+
_pinned_transfer_stats = {"pinned": 0, "fallback": 0}
|
|
22
|
+
|
|
23
|
+
def set_verbose(enabled: bool):
    """Enable/disable verbose output for pinned transfers.

    Stores the flag in module-level state consulted by transfer_to_gpu_pinned.
    """
    global _verbose
    _verbose = enabled
|
|
27
|
+
|
|
28
|
+
def get_pinned_transfer_stats():
    """Return a snapshot copy of the pinned/fallback transfer counters."""
    return dict(_pinned_transfer_stats)
|
|
31
|
+
|
|
32
|
+
def reset_pinned_transfer_stats():
    """Zero out the pinned/fallback transfer counters."""
    global _pinned_transfer_stats
    _pinned_transfer_stats = dict(pinned=0, fallback=0)
|
|
36
|
+
|
|
37
|
+
def transfer_to_gpu_pinned(
    tensor,
    device: str = 'cuda',
    dtype = None
):
    """Transfer tensor to GPU using pinned memory for faster transfer.

    Falls back to a plain .to() when:
      * the source is not a CPU tensor, or CUDA is unavailable;
      * the target device is not a CUDA device;
      * pinning or the pinned copy fails for any reason.

    Args:
        tensor: Source tensor.
        device: Target device string (default 'cuda').
        dtype: Optional dtype to convert to during the transfer.

    Returns:
        The tensor on `device` (converted to `dtype` if given).
    """
    torch = _ensure_torch()
    global _pinned_transfer_stats

    def _plain_to(t):
        # Ordinary (non-pinned) transfer, honoring the optional dtype.
        # Centralized here because it is the fallback for three branches.
        if dtype is not None:
            return t.to(device=device, dtype=dtype)
        return t.to(device=device)

    # Skip if not a CPU tensor or CUDA unavailable
    if tensor.device.type != 'cpu' or not torch.cuda.is_available():
        return _plain_to(tensor)

    # Skip if target is not CUDA
    if not str(device).startswith('cuda'):
        return _plain_to(tensor)

    try:
        pinned = tensor.pin_memory()

        if dtype is not None:
            result = pinned.to(device=device, dtype=dtype, non_blocking=True)
        else:
            result = pinned.to(device=device, non_blocking=True)

        # The non_blocking copy must complete before the pinned buffer may be
        # released, so synchronize the current stream here.
        torch.cuda.current_stream().synchronize()

        # One-time confirmation on first success
        if _pinned_transfer_stats["pinned"] == 0:
            logger.debug("[pinned_transfer] Pinned memory active - faster GPU transfers enabled")

        _pinned_transfer_stats["pinned"] += 1
        if _verbose:
            logger.debug(f"[pinned_transfer] Pinned: {tensor.shape} ({tensor.numel() * tensor.element_size() / 1024:.1f} KB)")
        else:
            logger.debug(f"[pinned_transfer] Transferred tensor {tensor.shape} to {device} via pinned memory")

        return result

    except Exception as e:
        _pinned_transfer_stats["fallback"] += 1
        if _verbose:
            logger.debug(f"[pinned_transfer] Fallback: {e}")
        else:
            logger.debug(f"[pinned_transfer] Fallback transfer to {device} due to error: {e}")

        return _plain_to(tensor)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tensor utility functions.
|
|
3
|
+
|
|
4
|
+
Provides serialization helpers for dictionary/tensor conversion.
|
|
5
|
+
Requires `torch`.
|
|
6
|
+
"""
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Dict, Tuple
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
def _ensure_torch():
|
|
14
|
+
try:
|
|
15
|
+
import torch
|
|
16
|
+
return torch
|
|
17
|
+
except ImportError:
|
|
18
|
+
raise ImportError("The 'torch' package is required but not installed. Please install it.")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def dict_to_tensor(data_dict: dict):
    """
    Convert a dictionary to a torch.uint8 tensor containing JSON bytes.

    Args:
        data_dict: Dictionary to serialize (must be JSON-serializable)

    Returns:
        1D torch.uint8 tensor containing UTF-8 encoded JSON
    """
    torch = _ensure_torch()
    json_str = json.dumps(data_dict)
    byte_data = json_str.encode("utf-8")
    # frombuffer over a writable bytearray avoids materializing a Python list
    # of ints (O(n) boxed objects) and avoids torch's non-writable-buffer
    # warning that raw bytes would trigger.
    tensor_data = torch.frombuffer(bytearray(byte_data), dtype=torch.uint8)
    logger.debug(f"dict_to_tensor: serialized dict to uint8 tensor of shape {tensor_data.shape}")
    return tensor_data
|
|
37
|
+
|
|
38
|
+
def tensor_to_dict(tensor_data) -> dict:
    """
    Decode a 1D torch.uint8 tensor of UTF-8 JSON bytes back into a dict.

    Args:
        tensor_data: Tensor containing UTF-8 encoded JSON bytes

    Returns:
        Parsed dictionary

    Raises:
        ValueError: if the tensor is not one-dimensional.
    """
    ndim = tensor_data.ndim
    if ndim != 1:
        raise ValueError(f"Expected a 1D tensor for dict conversion, got {ndim}D tensor.")
    decoded = bytes(tensor_data.tolist()).decode("utf-8")
    data_dict = json.loads(decoded)
    logger.debug(f"tensor_to_dict: deserialized tensor of shape {tensor_data.shape} to dict with keys: {list(data_dict.keys())}")
    return data_dict
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: unifiedefficientloader
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: A unified interface for memory efficient per tensor loading of safetensors files as raw bytes from offset, handling CPU/GPU pinned transfers, and converting between tensors and dicts.
|
|
5
|
+
Author: silveroxides
|
|
6
|
+
License: MIT
|
|
7
|
+
Classifier: Development Status :: 4 - Beta
|
|
8
|
+
Classifier: Intended Audience :: Developers
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Requires-Python: >=3.9
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Provides-Extra: torch
|
|
16
|
+
Requires-Dist: torch; extra == "torch"
|
|
17
|
+
Provides-Extra: safetensors
|
|
18
|
+
Requires-Dist: safetensors; extra == "safetensors"
|
|
19
|
+
Provides-Extra: tqdm
|
|
20
|
+
Requires-Dist: tqdm; extra == "tqdm"
|
|
21
|
+
Provides-Extra: all
|
|
22
|
+
Requires-Dist: torch; extra == "all"
|
|
23
|
+
Requires-Dist: safetensors; extra == "all"
|
|
24
|
+
Requires-Dist: tqdm; extra == "all"
|
|
25
|
+
Dynamic: license-file
|
|
26
|
+
|
|
27
|
+
# unifiedefficientloader
|
|
28
|
+
|
|
29
|
+
A unified interface for loading safetensors, handling CPU/GPU pinned transfers, and converting between tensors and dicts.
|
|
30
|
+
|
|
31
|
+
## Installation
|
|
32
|
+
|
|
33
|
+
You can install this package via pip. It relies heavily on `torch` and `safetensors`, but they are declared as optional extras rather than hard dependencies, so make sure they are installed in your environment:
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
pip install unifiedefficientloader
|
|
37
|
+
pip install torch safetensors tqdm
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Usage
|
|
41
|
+
|
|
42
|
+
### Unified Safetensors Loader
|
|
43
|
+
|
|
44
|
+
```python
|
|
45
|
+
from unifiedefficientloader import UnifiedSafetensorsLoader
|
|
46
|
+
|
|
47
|
+
# Standard mode (preload all)
|
|
48
|
+
with UnifiedSafetensorsLoader("model.safetensors", low_memory=False) as loader:
|
|
49
|
+
tensor = loader.get_tensor("weight_name")
|
|
50
|
+
|
|
51
|
+
# Low memory mode (streaming)
|
|
52
|
+
with UnifiedSafetensorsLoader("model.safetensors", low_memory=True) as loader:
|
|
53
|
+
for key in loader.keys():
|
|
54
|
+
tensor = loader.get_tensor(key)
|
|
55
|
+
# Process tensor...
|
|
56
|
+
loader.mark_processed(key) # No-op in low-memory mode (tensors are never cached)
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
### Loading Specific Tensors Dynamically (Header Analysis)
|
|
60
|
+
|
|
61
|
+
You can analyze the file's header without loading the entire multi-gigabyte safetensors file into memory. This allows you to locate specific data (like embedded JSON dictionaries stored as `uint8` tensors) and load *only* those specific tensors directly from their file offsets.
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
from unifiedefficientloader import UnifiedSafetensorsLoader, tensor_to_dict
|
|
65
|
+
|
|
66
|
+
with UnifiedSafetensorsLoader("model.safetensors", low_memory=True) as loader:
|
|
67
|
+
# 1. Analyze the header metadata without loading any tensors
|
|
68
|
+
# loader._header contains the full safetensors header directory
|
|
69
|
+
uint8_tensor_keys = [
|
|
70
|
+
key for key, info in loader._header.items()
|
|
71
|
+
if isinstance(info, dict) and info.get("dtype") == "U8"
|
|
72
|
+
]
|
|
73
|
+
|
|
74
|
+
# 2. Load ONLY those specific tensors using their keys
|
|
75
|
+
for key in uint8_tensor_keys:
|
|
76
|
+
# get_tensor dynamically reads only the bytes for this tensor
|
|
77
|
+
# based on the offsets found in the header
|
|
78
|
+
loaded_tensor = loader.get_tensor(key)
|
|
79
|
+
|
|
80
|
+
# 3. Decode the uint8 tensor back into a Python dictionary
|
|
81
|
+
extracted_dict = tensor_to_dict(loaded_tensor)
|
|
82
|
+
print(f"Decoded {key}:", extracted_dict)
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
### Optimized Asynchronous Streaming via ThreadPoolExecutor
|
|
86
|
+
|
|
87
|
+
For maximum I/O throughput while maintaining strict memory backpressure, use `async_stream`. This utilizes a `ThreadPoolExecutor` for background disk reading and a bounded queue to prevent memory exhaustion. By setting `pin_memory=True`, memory pinning is performed sequentially in the main thread to avoid OS-level lock contention and preserve high DMA transfer speeds.
|
|
88
|
+
|
|
89
|
+
```python
|
|
90
|
+
from unifiedefficientloader import UnifiedSafetensorsLoader, transfer_to_gpu_pinned
|
|
91
|
+
|
|
92
|
+
with UnifiedSafetensorsLoader("model.safetensors", low_memory=True) as loader:
|
|
93
|
+
keys_to_load = loader.keys()
|
|
94
|
+
|
|
95
|
+
# Create the continuous streaming generator
|
|
96
|
+
# prefetch_batches controls how many batches to buffer in memory
|
|
97
|
+
stream = loader.async_stream(
|
|
98
|
+
keys_to_load,
|
|
99
|
+
batch_size=8,
|
|
100
|
+
prefetch_batches=2,
|
|
101
|
+
pin_memory=True
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
# Iterate directly over the generator
|
|
105
|
+
for batch in stream:
|
|
106
|
+
for key, pinned_tensor in batch:
|
|
107
|
+
# Transfer directly to GPU via DMA (pinning is already done)
|
|
108
|
+
gpu_tensor = transfer_to_gpu_pinned(pinned_tensor, device="cuda")
|
|
109
|
+
|
|
110
|
+
# ... process gpu_tensor ...
|
|
111
|
+
loader.mark_processed(key)
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
### Tensor/Dict Conversion
|
|
115
|
+
|
|
116
|
+
```python
|
|
117
|
+
from unifiedefficientloader import dict_to_tensor, tensor_to_dict
|
|
118
|
+
|
|
119
|
+
my_dict = {"param": 1.0, "name": "test"}
|
|
120
|
+
tensor = dict_to_tensor(my_dict)
|
|
121
|
+
recovered_dict = tensor_to_dict(tensor)
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
### Pinned Memory Transfers
|
|
125
|
+
|
|
126
|
+
```python
|
|
127
|
+
import torch
|
|
128
|
+
from unifiedefficientloader import transfer_to_gpu_pinned
|
|
129
|
+
|
|
130
|
+
tensor = torch.randn(100, 100)
|
|
131
|
+
# Transfers using pinned memory if CUDA is available, otherwise falls back gracefully
|
|
132
|
+
gpu_tensor = transfer_to_gpu_pinned(tensor, device="cuda:0")
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
unifiedefficientloader/__init__.py,sha256=1BJELlRKRDkK_GDhg75xSBqQn3CSmDuvlxcSYcG6Mjw,1133
|
|
2
|
+
unifiedefficientloader/memory_efficient_loader.py,sha256=4HE_mmjmde_7Z5L8gCc99AX2pN9byMjEvlcKjVwooso,13408
|
|
3
|
+
unifiedefficientloader/pinned_transfer.py,sha256=ppMxVc9BY1fSmBgC_PTIs0NlCRY2cHHo33ZIn2zfcLo,2915
|
|
4
|
+
unifiedefficientloader/tensor_utils.py,sha256=KiHiCPY8x97xrG-G_JyK-MKZsIVCNh44Wu_2oVA6mzc,1606
|
|
5
|
+
unifiedefficientloader-0.2.0.dist-info/licenses/LICENSE,sha256=A9N3lbMEmsGuLe9EjTiWkDUkWHrKRRh_xcdvS3eHN5g,1063
|
|
6
|
+
unifiedefficientloader-0.2.0.dist-info/METADATA,sha256=L1le1VCz46O6zxt4Xn4OVtunEg6Dr330lDKTouGgLr4,4999
|
|
7
|
+
unifiedefficientloader-0.2.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
8
|
+
unifiedefficientloader-0.2.0.dist-info/top_level.txt,sha256=6PqrT67C60EgKKXNbtP7o3TBEWxV2XrLxQloZnGyREA,23
|
|
9
|
+
unifiedefficientloader-0.2.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Silver
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
unifiedefficientloader
|