ztensor 1.1.1__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ztensor/__init__.py +710 -0
- ztensor/numpy.py +192 -0
- ztensor/torch.py +472 -0
- ztensor/ztensor/__init__.py +7 -0
- ztensor/ztensor/ffi.py +10 -0
- ztensor/ztensor/ztensor.dll +0 -0
- ztensor-1.1.1.dist-info/METADATA +274 -0
- ztensor-1.1.1.dist-info/RECORD +10 -0
- ztensor-1.1.1.dist-info/WHEEL +4 -0
- ztensor-1.1.1.dist-info/licenses/LICENSE +21 -0
ztensor/__init__.py
ADDED
|
@@ -0,0 +1,710 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import importlib
|
|
3
|
+
from .ztensor import ffi, lib
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
# --- Optional PyTorch Import ---
|
|
7
|
+
try:
|
|
8
|
+
_torch = importlib.import_module("torch")
|
|
9
|
+
TORCH_AVAILABLE = True
|
|
10
|
+
except ImportError:
|
|
11
|
+
_torch = None
|
|
12
|
+
TORCH_AVAILABLE = False
|
|
13
|
+
|
|
14
|
+
# --- Optional ml_dtypes for bfloat16 in NumPy ---
|
|
15
|
+
try:
|
|
16
|
+
from ml_dtypes import bfloat16 as np_bfloat16
|
|
17
|
+
ML_DTYPES_AVAILABLE = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
np_bfloat16 = None
|
|
20
|
+
ML_DTYPES_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# --- Pythonic Wrapper ---
|
|
24
|
+
class ZTensorError(Exception):
    """Exception type raised by this module for ztensor-related failures."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _ZTensorView(np.ndarray):
    """ndarray subclass that carries a reference to the object backing its memory.

    The array is built over ``buffer`` without copying; ``view_ptr`` is stored
    on the instance as ``_owner`` so the object that owns the underlying
    memory stays referenced for as long as the array does.
    """

    def __new__(cls, buffer, dtype, shape, view_ptr):
        # frombuffer -> reshape -> view: none of these steps copies the data.
        flat = np.frombuffer(buffer, dtype=dtype)
        shaped = flat.reshape(shape)
        instance = shaped.view(cls)
        instance._owner = view_ptr
        return instance

    def __array_finalize__(self, obj):
        # ``obj`` is None on explicit construction; otherwise inherit the
        # owner reference from the array this one was derived from.
        if obj is not None:
            self._owner = getattr(obj, '_owner', None)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _get_last_error():
    """Return the last error message reported by the native library.

    Returns:
        str: The UTF-8 decoded message, or ``"Unknown FFI error"`` when the
        library hands back a NULL pointer.
    """
    message_ptr = lib.ztensor_last_error_message()
    if message_ptr == ffi.NULL:
        return "Unknown FFI error"
    return ffi.string(message_ptr).decode('utf-8')
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _check_ptr(ptr, func_name=""):
    """Return ``ptr`` unchanged, or raise if it is the FFI NULL pointer.

    Args:
        ptr: Pointer returned by an FFI call.
        func_name (str): Label included in the error message.

    Raises:
        ZTensorError: If ``ptr`` is NULL; the message includes the library's
            last error text.
    """
    if ptr != ffi.NULL:
        return ptr
    raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _check_status(status, func_name=""):
    """Raise when an FFI call returned a non-zero status code.

    Args:
        status (int): Status code returned by the FFI call; ``0`` is success.
        func_name (str): Label included in the error message.

    Raises:
        ZTensorError: If ``status`` is not ``0``; the message includes the
            library's last error text.
    """
    if status == 0:
        return
    raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# --- Type Mappings ---
|
|
63
|
+
# NumPy dtype -> zTensor dtype string. The zTensor names coincide with the
# NumPy dtype names, so the table is derived from the names themselves.
DTYPE_NP_TO_ZT = {
    np.dtype(zt_name): zt_name
    for zt_name in (
        'float64', 'float32', 'float16',
        'int64', 'int32',
        'int16', 'int8',
        'uint64', 'uint32',
        'uint16', 'uint8',
        'bool',
    )
}
# bfloat16 is only representable in NumPy when ml_dtypes is installed.
if ML_DTYPES_AVAILABLE:
    DTYPE_NP_TO_ZT[np.dtype(np_bfloat16)] = 'bfloat16'
# Reverse lookup: zTensor dtype string -> NumPy dtype.
DTYPE_ZT_TO_NP = {zt_name: np_dtype for np_dtype, zt_name in DTYPE_NP_TO_ZT.items()}

# The torch tables exist only when PyTorch could be imported.
if TORCH_AVAILABLE:
    DTYPE_TORCH_TO_ZT = {
        getattr(_torch, zt_name): zt_name
        for zt_name in (
            'float64', 'float32', 'float16',
            'bfloat16',
            'int64', 'int32',
            'int16', 'int8',
            'uint8', 'bool',
        )
    }
    DTYPE_ZT_TO_TORCH = {zt_name: torch_dtype for torch_dtype, zt_name in DTYPE_TORCH_TO_ZT.items()}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class TensorMetadata:
    """A Pythonic wrapper around the CTensorMetadata pointer.

    Each property queries the native library on first access and caches the
    decoded Python value, so later reads do not cross the FFI boundary again.
    """

    # Placeholder for optional fields that have not been queried yet. ``None``
    # cannot play this role because it is the legitimate value of an absent
    # field.
    _NOT_CHECKED = "not_checked"

    def __init__(self, meta_ptr):
        """Wrap ``meta_ptr`` and arrange for it to be freed with this object.

        Args:
            meta_ptr: cdata pointer to the native metadata. It is released via
                ``ztensor_metadata_free`` when the wrapper is garbage-collected.

        Raises:
            ZTensorError: If ``meta_ptr`` is NULL.
        """
        # Validate first so a NULL pointer never gets a destructor registered.
        _check_ptr(meta_ptr, "TensorMetadata constructor")
        self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
        self._name = None
        self._dtype_str = None
        self._shape = None
        self._offset = None
        self._size = None
        self._layout = None
        self._encoding = self._NOT_CHECKED
        self._endianness = self._NOT_CHECKED
        self._checksum = self._NOT_CHECKED

    def __repr__(self):
        return f"<TensorMetadata name='{self.name}' shape={self.shape} dtype='{self.dtype_str}'>"

    @staticmethod
    def _take_string(c_str):
        """Decode an FFI-returned C string as UTF-8 and release it.

        The string is freed even when decoding raises, so it cannot leak.
        """
        try:
            return ffi.string(c_str).decode('utf-8')
        finally:
            lib.ztensor_free_string(c_str)

    def _required_string(self, getter, label):
        """Fetch a string field that must be present; raise ZTensorError on NULL."""
        return self._take_string(_check_ptr(getter(self._ptr), label))

    def _optional_string(self, getter):
        """Fetch a string field that may be absent; a NULL result maps to None."""
        c_str = getter(self._ptr)
        if c_str == ffi.NULL:
            return None
        return self._take_string(c_str)

    @property
    def name(self):
        """The name of the tensor."""
        if self._name is None:
            self._name = self._required_string(lib.ztensor_metadata_get_name, "get_name")
        return self._name

    @property
    def dtype_str(self):
        """The zTensor dtype string (e.g., 'float32')."""
        if self._dtype_str is None:
            self._dtype_str = self._required_string(
                lib.ztensor_metadata_get_dtype_str, "get_dtype_str"
            )
        return self._dtype_str

    @property
    def dtype(self):
        """The numpy dtype for this tensor.

        Raises:
            ZTensorError: If the dtype string has no NumPy mapping (for
                'bfloat16' this means the 'ml_dtypes' package is missing).
        """
        dtype_str = self.dtype_str
        dt = DTYPE_ZT_TO_NP.get(dtype_str)
        if dt is None:
            if dtype_str == 'bfloat16':
                raise ZTensorError(
                    "Cannot read 'bfloat16' tensor as NumPy array because the 'ml_dtypes' "
                    "package is not installed. Please install it to proceed."
                )
            raise ZTensorError(f"Unsupported or unknown dtype string '{dtype_str}' found in tensor metadata.")
        return dt

    @property
    def shape(self):
        """The shape of the tensor as a tuple (``()`` when the shape length is 0)."""
        if self._shape is None:
            shape_len = lib.ztensor_metadata_get_shape_len(self._ptr)
            if shape_len > 0:
                shape_data_ptr = _check_ptr(
                    lib.ztensor_metadata_get_shape_data(self._ptr), "get_shape_data"
                )
                try:
                    self._shape = tuple(shape_data_ptr[i] for i in range(shape_len))
                finally:
                    # Release the native array even if reading it failed.
                    lib.ztensor_free_u64_array(shape_data_ptr, shape_len)
            else:
                self._shape = ()
        return self._shape

    @property
    def offset(self):
        """The on-disk offset of the tensor data in bytes."""
        if self._offset is None:
            self._offset = lib.ztensor_metadata_get_offset(self._ptr)
        return self._offset

    @property
    def size(self):
        """The on-disk size of the tensor data in bytes (can be compressed size)."""
        if self._size is None:
            self._size = lib.ztensor_metadata_get_size(self._ptr)
        return self._size

    @property
    def layout(self):
        """The tensor layout as a string (e.g., 'dense')."""
        if self._layout is None:
            self._layout = self._required_string(
                lib.ztensor_metadata_get_layout_str, "get_layout_str"
            )
        return self._layout

    @property
    def encoding(self):
        """The tensor encoding as a string (e.g., 'raw', 'zstd'), or None if absent."""
        # Compared against the placeholder rather than None so that an absent
        # encoding is cached instead of being re-queried on every access.
        if self._encoding == self._NOT_CHECKED:
            self._encoding = self._optional_string(lib.ztensor_metadata_get_encoding_str)
        return self._encoding

    @property
    def endianness(self):
        """The data endianness ('little', 'big') if applicable, else None."""
        if self._endianness == self._NOT_CHECKED:
            self._endianness = self._optional_string(
                lib.ztensor_metadata_get_data_endianness_str
            )
        return self._endianness

    @property
    def checksum(self):
        """The checksum string if present, else None."""
        if self._checksum == self._NOT_CHECKED:
            self._checksum = self._optional_string(lib.ztensor_metadata_get_checksum_str)
        return self._checksum
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class Reader:
    """A Pythonic context manager for reading zTensor files.

    The native reader handle is released when the last reference to it is
    dropped. Leaving the ``with`` block marks the Reader as closed, but results
    that borrow memory from the native reader keep the handle referenced until
    they are garbage-collected themselves.
    """

    def __init__(self, file_path):
        """Open ``file_path`` for reading.

        Args:
            file_path: ``str``, ``bytes`` or ``os.PathLike`` path of the file.
                ``str`` paths are passed to the native library as UTF-8.

        Raises:
            ZTensorError: If the native library fails to open the file.
        """
        import os  # local import: the file's import header is outside this block

        fs_path = os.fspath(file_path)
        path_bytes = fs_path if isinstance(fs_path, bytes) else fs_path.encode('utf-8')
        ptr = lib.ztensor_reader_open(path_bytes)
        _check_ptr(ptr, f"Reader open: {file_path}")
        self._ptr = ffi.gc(ptr, lib.ztensor_reader_free)
        self._tensor_names_cache = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Dropping the gc-managed handle closes the reader once nothing else
        # (e.g. a zero-copy result) still references it.
        self._ptr = None
        self._tensor_names_cache = None

    def _require_open(self):
        """Raise ZTensorError if the reader has been closed."""
        if self._ptr is None:
            raise ZTensorError("Reader is closed.")

    def __len__(self):
        """Returns the number of tensors in the file."""
        self._require_open()
        return lib.ztensor_reader_get_metadata_count(self._ptr)

    def __iter__(self):
        """Iterates over the metadata of all tensors in the file."""
        self._require_open()
        for name in self.tensor_names:
            yield self.metadata(name)

    def __getitem__(self, key):
        """
        Retrieves metadata (int key) or reads tensor data (str key).

        Negative integers index from the end, as with Python sequences.

        Raises:
            IndexError: If an integer key is out of range.
            TypeError: If ``key`` is neither ``int`` nor ``str``.
            ZTensorError: If the reader is closed or the native call fails.
        """
        self._require_open()

        if isinstance(key, int):
            count = len(self)
            index = key + count if key < 0 else key
            # Reject out-of-range values here so a negative number is never
            # handed to the native index lookup.
            if not 0 <= index < count:
                raise IndexError("Tensor index out of range")
            meta_ptr = lib.ztensor_reader_get_metadata_by_index(self._ptr, index)
            _check_ptr(meta_ptr, f"get_metadata_by_index: {key}")
            return TensorMetadata(meta_ptr)
        elif isinstance(key, str):
            return self.read_tensor(key)
        else:
            raise TypeError(f"Invalid argument type for __getitem__: {type(key)}")

    def __contains__(self, name: str) -> bool:
        """Checks if a tensor with the given name exists in the file."""
        return name in self.tensor_names

    @property
    def tensors(self) -> "list[TensorMetadata]":
        """Returns a list of all TensorMetadata objects in the file."""
        return list(self)

    @property
    def tensor_names(self) -> list[str]:
        """Returns a list of all tensor names in the file (cached)."""
        self._require_open()
        if self._tensor_names_cache is None:
            c_array_ptr = lib.ztensor_reader_get_all_tensor_names(self._ptr)
            _check_ptr(c_array_ptr, "get_all_tensor_names")
            # The native array is freed once the decoded copies below exist
            # and this local reference goes away.
            c_array_ptr = ffi.gc(c_array_ptr, lib.ztensor_free_string_array)
            self._tensor_names_cache = [
                ffi.string(c_array_ptr.strings[i]).decode('utf-8')
                for i in range(c_array_ptr.len)
            ]
        return self._tensor_names_cache

    def metadata(self, name: str) -> "TensorMetadata":
        """Retrieves metadata for a tensor by its name."""
        self._require_open()
        name_bytes = name.encode('utf-8')
        meta_ptr = lib.ztensor_reader_get_metadata_by_name(self._ptr, name_bytes)
        _check_ptr(meta_ptr, f"metadata: {name}")
        return TensorMetadata(meta_ptr)

    # Legacy aliases
    def list_tensors(self): return self.tensors
    def get_tensor_names(self): return self.tensor_names
    def get_metadata(self, name): return self.metadata(name)

    def _read_component(self, tensor_name: str, component_name: str, dtype_func):
        """Reads a specific component as a numpy array.

        Returns:
            tuple: ``(array, view_ptr)``. The array is built over the native
            buffer without copying, so the caller must keep ``view_ptr``
            referenced for as long as the array is in use.
        """
        t_name_bytes = tensor_name.encode('utf-8')
        c_name_bytes = component_name.encode('utf-8')

        view_ptr = lib.ztensor_reader_read_tensor_component(self._ptr, t_name_bytes, c_name_bytes)
        _check_ptr(view_ptr, f"read_component: {tensor_name}.{component_name}")
        view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)

        buffer = ffi.buffer(view_ptr.data, view_ptr.len)
        arr = np.frombuffer(buffer, dtype=dtype_func())
        return arr, view_ptr

    def _wrap_dense(self, data, length, metadata, to, owner, zero_copy):
        """Expose ``length`` bytes at ``data`` as a NumPy array or PyTorch tensor.

        Args:
            data: cdata pointer to the first byte of the tensor.
            length: Number of bytes available at ``data``.
            metadata (TensorMetadata): Supplies dtype and shape.
            to (str): 'numpy' or 'torch'.
            owner: gc-managed cdata that frees the view; attached to the result
                as ``_owner`` so the bytes stay valid.
            zero_copy (bool): True when the bytes belong to the native reader
                rather than to ``owner``.
        """
        if to == 'numpy':
            result = _ZTensorView(
                buffer=ffi.buffer(data, length),
                dtype=metadata.dtype,
                shape=metadata.shape,
                view_ptr=owner
            )
        else:
            if not TORCH_AVAILABLE: raise ZTensorError("PyTorch not installed.")
            torch_dtype = DTYPE_ZT_TO_TORCH.get(metadata.dtype_str)
            if torch_dtype is None:
                # Fail clearly instead of passing dtype=None to torch.frombuffer.
                raise ZTensorError(
                    f"Unsupported or unknown dtype string '{metadata.dtype_str}' for PyTorch."
                )
            # Note: frombuffer shares memory with the buffer; the source may be
            # read-only, which is fine for inference-style use.
            buffer = ffi.buffer(data, length)
            result = _torch.frombuffer(buffer, dtype=torch_dtype).reshape(metadata.shape)
            result._owner = owner

        if zero_copy:
            # Reference the native handle itself, not the Reader wrapper:
            # __exit__ clears self._ptr, so a reference to ``self`` alone would
            # let the native reader be freed while this result is still alive.
            # NOTE(review): presumably zero-copy views point into memory owned
            # by the native reader - confirm against the native implementation.
            result._reader_ref = self._ptr
        return result

    def _read_dense(self, name, metadata, to, verify_checksum):
        """Read a dense tensor, preferring a zero-copy slice for raw encoding."""
        name_bytes = name.encode('utf-8')
        view_ptr = ffi.NULL

        # The slice call takes no checksum flag, so a verified read has to go
        # through the regular read path below.
        if metadata.encoding == 'raw' and not verify_checksum:
            try:
                view_ptr = lib.ztensor_reader_get_tensor_slice(self._ptr, name_bytes)
            except AttributeError:
                # Fallback if library doesn't have the function yet (e.g. old build)
                pass

        is_zero_copy = view_ptr != ffi.NULL
        if not is_zero_copy:
            # Fallback to standard read (copy)
            view_ptr = lib.ztensor_reader_read_tensor(self._ptr, name_bytes, 1 if verify_checksum else 0)
            _check_ptr(view_ptr, f"read_tensor: {name}")

        view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)
        return self._wrap_dense(view_ptr.data, view_ptr.len, metadata, to, view_ptr, is_zero_copy)

    def _read_sparse_csr(self, name, metadata, to):
        """Read a 'sparse_csr' tensor as a scipy CSR matrix or torch CSR tensor."""
        vals, v_ref = self._read_component(name, "values", lambda: metadata.dtype)
        idxs, i_ref = self._read_component(name, "indices", lambda: np.uint64)
        ptrs, p_ref = self._read_component(name, "indptr", lambda: np.uint64)

        if to == 'numpy':
            try:
                from scipy.sparse import csr_matrix
            except ImportError:
                raise ZTensorError("scipy is required for reading sparse tensors as numpy.")
            # The values array is passed as read from the native buffer; the
            # index arrays are converted (and therefore copied) to int32.
            result = csr_matrix((vals, idxs.astype(np.int32), ptrs.astype(np.int32)), shape=metadata.shape)
            # Keep FFI view pointers alive for the lifetime of the sparse matrix
            result._ztensor_owners = (v_ref, i_ref, p_ref)
            return result

        if not TORCH_AVAILABLE: raise ZTensorError("No Torch.")
        # Copy everything so the torch tensor does not depend on the FFI views.
        t_vals = _torch.from_numpy(vals.copy())
        t_indptr = _torch.from_numpy(ptrs.astype(np.int64))
        t_indices = _torch.from_numpy(idxs.astype(np.int64))
        return _torch.sparse_csr_tensor(t_indptr, t_indices, t_vals, size=metadata.shape)

    def _read_sparse_coo(self, name, metadata, to):
        """Read a 'sparse_coo' tensor as a scipy COO matrix or torch COO tensor."""
        vals, v_ref = self._read_component(name, "values", lambda: metadata.dtype)
        coords, c_ref = self._read_component(name, "coords", lambda: np.uint64)

        nnz = vals.shape[0]
        ndim = len(metadata.shape)
        # Coordinates are stored flat; view them as one row per dimension.
        coords = coords.reshape((ndim, nnz))

        if to == 'numpy':
            if ndim != 2: raise ZTensorError("Scipy COO only supports 2D.")
            try:
                from scipy.sparse import coo_matrix
            except ImportError:
                raise ZTensorError("scipy is required for reading sparse tensors as numpy.")
            result = coo_matrix((vals, (coords[0], coords[1])), shape=metadata.shape)
            # Keep FFI view pointers alive for the lifetime of the sparse matrix
            result._ztensor_owners = (v_ref, c_ref)
            return result

        if not TORCH_AVAILABLE: raise ZTensorError("No Torch.")
        # Copy everything so the torch tensor does not depend on the FFI views.
        t_vals = _torch.from_numpy(vals.copy())
        t_indices = _torch.from_numpy(coords.astype(np.int64))
        return _torch.sparse_coo_tensor(t_indices, t_vals, size=metadata.shape)

    def read_tensor(self, name: str, to: str = 'numpy', verify_checksum: bool = False):
        """
        Reads a tensor by name and returns it as a NumPy array or PyTorch tensor.

        Args:
            name (str): The name of the tensor to read.
            to (str): The desired output format. 'numpy' or 'torch'.
            verify_checksum (bool): If True, verify checksums during read (slower).
                A verified dense read does not use the zero-copy slice path.
                Default: False.

        Returns:
            Union[np.ndarray, torch.Tensor, scipy.sparse.spmatrix, torch.sparse_coo_tensor]

        Raises:
            ValueError: If ``to`` is not 'numpy' or 'torch'.
            ZTensorError: If the reader is closed, the layout or dtype is
                unsupported, a required optional package is missing, or the
                native call fails.
        """
        self._require_open()
        if to not in ['numpy', 'torch']:
            raise ValueError(f"Unsupported format: '{to}'.")

        metadata = self.metadata(name)
        layout = metadata.layout

        if layout == "dense":
            return self._read_dense(name, metadata, to, verify_checksum)
        if layout == "sparse_csr":
            return self._read_sparse_csr(name, metadata, to)
        if layout == "sparse_coo":
            return self._read_sparse_coo(name, metadata, to)
        raise ZTensorError(f"Unsupported layout: {layout}")

    def read_tensors(self, names: list[str], to: str = 'numpy', verify_checksum: bool = False) -> list:
        """
        Reads multiple tensors in batch.

        Dense tensors are fetched with a single batch FFI call; tensors with
        any other layout fall back to individual ``read_tensor`` calls.

        Args:
            names (list[str]): List of tensor names to read.
            to (str): The desired output format. 'numpy' or 'torch'.
            verify_checksum (bool): If True, verify checksums during read. Default: False.

        Returns:
            List of tensors in the same order as input names.

        Raises:
            ValueError: If ``to`` is not 'numpy' or 'torch'.
            ZTensorError: If the reader is closed or a read fails.
        """
        self._require_open()
        if to not in ['numpy', 'torch']:
            raise ValueError(f"Unsupported format: '{to}'.")

        # Materialize once: the names are traversed several times below.
        names = list(names)

        # First pass: collect metadata and identify dense tensors
        metadatas = [self.metadata(name) for name in names]
        dense = [(name, meta) for name, meta in zip(names, metadatas) if meta.layout == "dense"]

        # Batch read all dense tensors
        dense_results = {}
        if dense:
            if to == 'torch' and not TORCH_AVAILABLE:
                # Checked before the native read so no work is wasted.
                raise ZTensorError("PyTorch not installed.")

            # name_ptrs must stay referenced: the char*[] array only stores
            # the addresses of these buffers.
            name_ptrs = [ffi.new("char[]", name.encode('utf-8')) for name, _ in dense]
            name_ptr_array = ffi.new("char*[]", name_ptrs)

            arr_ptr = lib.ztensor_reader_read_tensors(
                self._ptr,
                name_ptr_array,
                len(dense),
                1 if verify_checksum else 0
            )
            _check_ptr(arr_ptr, "read_tensors batch")
            # Keep arr_ptr alive - it owns all the views
            arr_ptr = ffi.gc(arr_ptr, lib.ztensor_free_tensor_view_array)

            for idx, (name, meta) in enumerate(dense):
                view = arr_ptr.views[idx]
                # A NULL owner marks a zero-copy view, which additionally needs
                # the native reader kept alive.
                dense_results[name] = self._wrap_dense(
                    view.data, view.len, meta, to,
                    owner=arr_ptr,
                    zero_copy=view._owner == ffi.NULL,
                )

        # Assemble final results; anything not read in the batch is read individually.
        return [
            dense_results[name] if name in dense_results
            else self.read_tensor(name, to=to, verify_checksum=verify_checksum)
            for name in names
        ]
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
class Writer:
    """A Pythonic context manager for writing zTensor files.

    On a clean exit from the ``with`` block the file is finalized; if the
    block raised, the native writer is freed without finalizing.
    """

    def __init__(self, file_path):
        """Create ``file_path`` for writing.

        Args:
            file_path: ``str``, ``bytes`` or ``os.PathLike`` path of the file.
                ``str`` paths are passed to the native library as UTF-8.

        Raises:
            ZTensorError: If the native library fails to create the writer.
        """
        import os  # local import: the file's import header is outside this block

        fs_path = os.fspath(file_path)
        path_bytes = fs_path if isinstance(fs_path, bytes) else fs_path.encode('utf-8')
        ptr = lib.ztensor_writer_create(path_bytes)
        _check_ptr(ptr, f"Writer create: {file_path}")
        self._ptr = ptr
        self._finalized = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._ptr and not self._finalized:
            if exc_type is None:
                self.finalize()
            else:
                # The block failed: release the writer without finalizing.
                lib.ztensor_writer_free(self._ptr)
        self._ptr = None

    def add_tensor(self, name: str, tensor, compress: Union[bool, int] = False):
        """
        Adds a NumPy or PyTorch tensor to the file.

        Args:
            name (str): The name of the tensor to add.
            tensor (np.ndarray or torch.Tensor): The tensor data to write.
            compress (bool or int): Compression settings.
                - False / 0: No compression (Raw).
                - True: Default Zstd compression (Level 3).
                - int > 0: Specific Zstd compression level (e.g., 1-22).

        Raises:
            TypeError: If ``compress`` or ``tensor`` has an unsupported type.
            ZTensorError: If the writer is closed, the tensor is on a CUDA
                device, its dtype is unsupported, or the native call fails.
        """
        if not self._ptr: raise ZTensorError("Writer is closed or finalized.")

        # Resolve compression level
        if compress is True:
            compression_level = 3  # Default zstd level
        elif compress is False or compress is None:
            compression_level = 0
        elif isinstance(compress, int):
            compression_level = compress
        else:
            raise TypeError(f"Invalid type for 'compress': {type(compress)}. Expected bool or int.")

        if isinstance(tensor, np.ndarray):
            # Record the shape first: np.ascontiguousarray returns an array
            # with ndim >= 1, so a 0-d array would otherwise be written with
            # shape (1,) instead of ().
            shape = tensor.shape
            tensor = np.ascontiguousarray(tensor)
            dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
            data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
            nbytes = tensor.nbytes

        elif TORCH_AVAILABLE and isinstance(tensor, _torch.Tensor):
            if tensor.is_cuda:
                raise ZTensorError("Cannot write directly from a CUDA tensor. Copy to CPU first using .cpu().")
            tensor = tensor.contiguous()
            shape = tuple(tensor.shape)
            dtype_str = DTYPE_TORCH_TO_ZT.get(tensor.dtype)
            data_ptr = ffi.cast("unsigned char*", tensor.data_ptr())
            nbytes = tensor.numel() * tensor.element_size()

        else:
            supported = "np.ndarray" + (" or torch.Tensor" if TORCH_AVAILABLE else "")
            raise TypeError(f"Unsupported tensor type: {type(tensor)}. Must be {supported}.")

        if not dtype_str:
            msg = f"Unsupported dtype: {tensor.dtype}."
            if 'bfloat16' in str(tensor.dtype) and not ML_DTYPES_AVAILABLE:
                msg += " For NumPy bfloat16 support, please install the 'ml_dtypes' package."
            raise ZTensorError(msg)

        name_bytes = name.encode('utf-8')
        shape_array = np.array(shape, dtype=np.uint64)
        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
        dtype_bytes = dtype_str.encode('utf-8')

        # ffi.cast does not keep its source alive; the local names ``tensor``
        # and ``shape_array`` stay referenced until this call returns.
        status = lib.ztensor_writer_add_tensor(
            self._ptr, name_bytes, shape_ptr, len(shape),
            dtype_bytes, data_ptr, nbytes, compression_level
        )
        _check_status(status, f"add_tensor: {name}")

    def add_sparse_csr(self, name: str, values, indices, indptr, shape):
        """Adds a sparse CSR tensor.

        Args:
            name (str): The name of the tensor to add.
            values (np.ndarray or torch.Tensor): Stored values.
            indices (np.ndarray or torch.Tensor): Index array; converted to a
                64-bit integer dtype before being passed to the native library.
            indptr (np.ndarray or torch.Tensor): Index-pointer array; converted
                like ``indices``.
            shape: Sequence of dimension sizes.

        Raises:
            TypeError: If an array is neither np.ndarray nor torch.Tensor.
            ZTensorError: If the writer is closed, an array is on a CUDA
                device, the value dtype is unsupported, or the native call fails.
        """
        if not self._ptr: raise ZTensorError("Writer is closed.")

        # Every array whose address is handed to the FFI is stored here so it
        # cannot be garbage-collected before the call returns.
        keepalive = []

        def values_info(arr):
            """Return (byte pointer, byte length, zTensor dtype string or None)."""
            if isinstance(arr, np.ndarray):
                arr = np.ascontiguousarray(arr)
                keepalive.append(arr)
                return ffi.cast("unsigned char*", arr.ctypes.data), arr.nbytes, DTYPE_NP_TO_ZT.get(arr.dtype)
            elif TORCH_AVAILABLE and isinstance(arr, _torch.Tensor):
                if arr.is_cuda: raise ZTensorError("CUDA tensors not supported. Move to CPU.")
                arr = arr.contiguous()
                keepalive.append(arr)
                return (
                    ffi.cast("unsigned char*", arr.data_ptr()),
                    arr.numel() * arr.element_size(),
                    DTYPE_TORCH_TO_ZT.get(arr.dtype),
                )
            else:
                raise TypeError(f"Unsupported array type: {type(arr)}")

        def index_info(arr):
            """Return (uint64_t pointer, element count) for an index array."""
            if isinstance(arr, np.ndarray):
                if arr.dtype != np.uint64:
                    arr = arr.astype(np.uint64)
                arr = np.ascontiguousarray(arr)
                keepalive.append(arr)
                return ffi.cast("uint64_t*", arr.ctypes.data), arr.size
            elif TORCH_AVAILABLE and isinstance(arr, _torch.Tensor):
                if arr.is_cuda: raise ZTensorError("CUDA tensors not supported. Move to CPU.")
                # The pointer is read as 64-bit elements, so narrower index
                # tensors must be widened first; otherwise the native side
                # would read past the end of the buffer.
                if arr.dtype != _torch.int64:
                    arr = arr.to(_torch.int64)
                arr = arr.contiguous()
                keepalive.append(arr)
                return ffi.cast("uint64_t*", arr.data_ptr()), arr.numel()
            else:
                raise TypeError(f"Unsupported array type: {type(arr)}")

        v_ptr, v_len, v_dtype = values_info(values)
        if not v_dtype:
            # Report an unsupported dtype explicitly instead of failing on
            # None.encode() below.
            raise ZTensorError(f"Unsupported dtype: {values.dtype}.")

        i_ptr_u64, i_cnt = index_info(indices)
        p_ptr_u64, p_cnt = index_info(indptr)

        name_bytes = name.encode('utf-8')
        shape_array = np.array(shape, dtype=np.uint64)
        keepalive.append(shape_array)
        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
        dtype_bytes = v_dtype.encode('utf-8')

        status = lib.ztensor_writer_add_sparse_csr(
            self._ptr,
            name_bytes,
            shape_ptr, len(shape),
            dtype_bytes,
            v_ptr, v_len,
            i_ptr_u64, i_cnt,
            p_ptr_u64, p_cnt,
        )
        _check_status(status, f"add_sparse_csr: {name}")

    def add_sparse_coo(self, name: str, values, indices, shape):
        """Adds a sparse COO tensor.

        Args:
            name (str): The name of the tensor to add.
            values (np.ndarray or torch.Tensor): Stored values.
            indices (np.ndarray or torch.Tensor): Coordinate array; converted
                to a 64-bit integer dtype before being passed to the native
                library.
            shape: Sequence of dimension sizes.

        Raises:
            TypeError: If an array is neither np.ndarray nor torch.Tensor.
            ZTensorError: If the writer is closed, an array is on a CUDA
                device, the value dtype is unsupported, or the native call fails.
        """
        if not self._ptr: raise ZTensorError("Writer is closed.")

        def get_buffer_info(arr):
            """Return (kept-alive array, byte pointer, byte length, dtype string or None)."""
            if isinstance(arr, np.ndarray):
                arr = np.ascontiguousarray(arr)
                return arr, ffi.cast("unsigned char*", arr.ctypes.data), arr.nbytes, DTYPE_NP_TO_ZT.get(arr.dtype)
            elif TORCH_AVAILABLE and isinstance(arr, _torch.Tensor):
                if arr.is_cuda: raise ZTensorError("No CUDA")
                arr = arr.contiguous()
                return arr, ffi.cast("unsigned char*", arr.data_ptr()), arr.numel() * arr.element_size(), DTYPE_TORCH_TO_ZT.get(arr.dtype)
            raise TypeError("Unsupported")

        def ensure_u64_ptr(arr):
            """Return (uint64_t pointer, kept-alive array, element count)."""
            if isinstance(arr, np.ndarray):
                if arr.dtype != np.uint64: arr = arr.astype(np.uint64)
                arr = np.ascontiguousarray(arr)
                return ffi.cast("uint64_t*", arr.ctypes.data), arr, arr.size
            elif TORCH_AVAILABLE and isinstance(arr, _torch.Tensor):
                # A device pointer must never reach the native library.
                if arr.is_cuda: raise ZTensorError("No CUDA")
                if arr.dtype != _torch.int64: arr = arr.to(_torch.int64)
                arr = arr.contiguous()
                return ffi.cast("uint64_t*", arr.data_ptr()), arr, arr.numel()
            raise TypeError("Unsupported")

        # The *_keep locals exist only to hold the arrays until the call returns.
        _v_keep, v_ptr, v_len, v_dtype = get_buffer_info(values)
        if not v_dtype:
            # Report an unsupported dtype explicitly instead of failing on
            # None.encode() below.
            raise ZTensorError(f"Unsupported dtype: {values.dtype}.")

        i_ptr_u64, _i_keep, i_count = ensure_u64_ptr(indices)

        name_bytes = name.encode('utf-8')
        shape_array = np.array(shape, dtype=np.uint64)
        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
        dtype_bytes = v_dtype.encode('utf-8')

        status = lib.ztensor_writer_add_sparse_coo(
            self._ptr,
            name_bytes,
            shape_ptr, len(shape),
            dtype_bytes,
            v_ptr, v_len,
            i_ptr_u64, i_count
        )
        _check_status(status, f"add_sparse_coo: {name}")

    def finalize(self):
        """Finalizes the zTensor file, writing the metadata index.

        The writer is marked closed before the status is checked, so it cannot
        be used again even when finalization reports an error.

        Raises:
            ZTensorError: If the writer is already closed or finalization fails.
        """
        if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
        status = lib.ztensor_writer_finalize(self._ptr)
        self._ptr = None
        self._finalized = True
        _check_status(status, "finalize")
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
# Names exported by ``from <package> import *``.
__all__ = [
    "Reader",
    "Writer",
    "TensorMetadata",
    "ZTensorError",
]
|