ztensor 0.1.1__py3-none-win_amd64.whl → 0.1.2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ztensor might be problematic.
- ztensor/__init__.py +127 -55
- ztensor/ztensor/ztensor.dll +0 -0
- {ztensor-0.1.1.dist-info → ztensor-0.1.2.dist-info}/METADATA +1 -1
- ztensor-0.1.2.dist-info/RECORD +8 -0
- ztensor-0.1.1.dist-info/RECORD +0 -8
- {ztensor-0.1.1.dist-info → ztensor-0.1.2.dist-info}/WHEEL +0 -0
- {ztensor-0.1.1.dist-info → ztensor-0.1.2.dist-info}/licenses/LICENSE +0 -0
ztensor/__init__.py
CHANGED
@@ -1,6 +1,22 @@
 import numpy as np
 from .ztensor import ffi, lib
-
+
+# --- Optional PyTorch Import ---
+try:
+    import torch
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+
+# --- Optional ml_dtypes for bfloat16 in NumPy ---
+try:
+    from ml_dtypes import bfloat16 as np_bfloat16
+
+    ML_DTYPES_AVAILABLE = True
+except ImportError:
+    np_bfloat16 = None
+    ML_DTYPES_AVAILABLE = False


 # --- Pythonic Wrapper ---
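Editor's note: the optional imports above degrade gracefully, and the module-level flags record what was found at import time. A minimal sketch of how a caller might inspect those flags (based only on the behavior visible in this diff):

import ztensor

# Flags set at import time by the try/except blocks in the hunk above.
if not ztensor.TORCH_AVAILABLE:
    print("PyTorch not found: read_tensor(..., to='torch') will raise ZTensorError")
if not ztensor.ML_DTYPES_AVAILABLE:
    print("ml_dtypes not found: bfloat16 tensors cannot be read as NumPy arrays")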
@@ -12,7 +28,6 @@ class ZTensorError(Exception):
 # A custom ndarray subclass to safely manage the lifetime of the CFFI pointer.
 class _ZTensorView(np.ndarray):
     def __new__(cls, buffer, dtype, shape, view_ptr):
-        # Create an array from the buffer, reshape it, and cast it to our custom type.
         obj = np.frombuffer(buffer, dtype=dtype).reshape(shape).view(cls)
         # Attach the object that owns the memory to an attribute.
         obj._owner = view_ptr
@@ -45,26 +60,36 @@ def _check_status(status, func_name=""):
         raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")


-# Type Mappings
+# --- Type Mappings ---
+# NumPy Mappings
 DTYPE_NP_TO_ZT = {
-    np.dtype('float64'): 'float64', np.dtype('float32'): 'float32',
+    np.dtype('float64'): 'float64', np.dtype('float32'): 'float32', np.dtype('float16'): 'float16',
     np.dtype('int64'): 'int64', np.dtype('int32'): 'int32',
     np.dtype('int16'): 'int16', np.dtype('int8'): 'int8',
     np.dtype('uint64'): 'uint64', np.dtype('uint32'): 'uint32',
     np.dtype('uint16'): 'uint16', np.dtype('uint8'): 'uint8',
     np.dtype('bool'): 'bool',
-    # ADDED: Mapping for bfloat16 to handle writing
-    np.dtype(bfloat16): 'bfloat16',
 }
-
+if ML_DTYPES_AVAILABLE:
+    DTYPE_NP_TO_ZT[np.dtype(np_bfloat16)] = 'bfloat16'
 DTYPE_ZT_TO_NP = {v: k for k, v in DTYPE_NP_TO_ZT.items()}

+# PyTorch Mappings (if available)
+if TORCH_AVAILABLE:
+    DTYPE_TORCH_TO_ZT = {
+        torch.float64: 'float64', torch.float32: 'float32', torch.float16: 'float16',
+        torch.bfloat16: 'bfloat16',
+        torch.int64: 'int64', torch.int32: 'int32',
+        torch.int16: 'int16', torch.int8: 'int8',
+        torch.uint8: 'uint8', torch.bool: 'bool',
+    }
+    DTYPE_ZT_TO_TORCH = {v: k for k, v in DTYPE_TORCH_TO_ZT.items()}
+

 class TensorMetadata:
     """A Pythonic wrapper around the CTensorMetadata pointer."""

     def __init__(self, meta_ptr):
-        # The pointer is now automatically garbage collected by CFFI when this object dies.
         self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
         _check_ptr(self._ptr, "TensorMetadata constructor")
         self._name = None
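Editor's note: the mapping tables above are plain module-level dictionaries, so a caller can check how a given dtype will be stored on disk. A small sketch, assuming the package is imported as ztensor (float16 is newly mapped in this release; bfloat16 appears only when ml_dtypes is installed):

import numpy as np
import ztensor

print(ztensor.DTYPE_NP_TO_ZT[np.dtype('float16')])  # 'float16' (newly supported)
print('bfloat16' in ztensor.DTYPE_ZT_TO_NP)         # True only if ml_dtypes is installed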
@@ -76,7 +101,6 @@ class TensorMetadata:
         if self._name is None:
             name_ptr = lib.ztensor_metadata_get_name(self._ptr)
             _check_ptr(name_ptr, "get_name")
-            # ffi.string creates a copy, so we must free the Rust-allocated original.
             self._name = ffi.string(name_ptr).decode('utf-8')
             lib.ztensor_free_string(name_ptr)
         return self._name
@@ -86,7 +110,6 @@ class TensorMetadata:
         if self._dtype_str is None:
             dtype_ptr = lib.ztensor_metadata_get_dtype_str(self._ptr)
             _check_ptr(dtype_ptr, "get_dtype_str")
-            # ffi.string creates a copy, so we must free the Rust-allocated original.
             self._dtype_str = ffi.string(dtype_ptr).decode('utf-8')
             lib.ztensor_free_string(dtype_ptr)
         return self._dtype_str
@@ -94,10 +117,17 @@ class TensorMetadata:
     @property
     def dtype(self):
         """Returns the numpy dtype for this tensor."""
-
-
+        dtype_str = self.dtype_str
+        dt = DTYPE_ZT_TO_NP.get(dtype_str)
+        if dt is None:
+            if dtype_str == 'bfloat16':
+                raise ZTensorError(
+                    "Cannot read 'bfloat16' tensor as NumPy array because the 'ml_dtypes' "
+                    "package is not installed. Please install it to proceed."
+                )
+            raise ZTensorError(f"Unsupported or unknown dtype string '{dtype_str}' found in tensor metadata.")
+        return dt

-    # RE-ENABLED: This property now works because the underlying FFI functions are available.
     @property
     def shape(self):
         if self._shape is None:
@@ -106,7 +136,6 @@ class TensorMetadata:
                 shape_data_ptr = lib.ztensor_metadata_get_shape_data(self._ptr)
                 _check_ptr(shape_data_ptr, "get_shape_data")
                 self._shape = tuple(shape_data_ptr[i] for i in range(shape_len))
-                # Free the array that was allocated on the Rust side.
                 lib.ztensor_free_u64_array(shape_data_ptr, shape_len)
             else:
                 self._shape = tuple()
@@ -120,15 +149,12 @@ class Reader:
         path_bytes = file_path.encode('utf-8')
         ptr = lib.ztensor_reader_open(path_bytes)
         _check_ptr(ptr, f"Reader open: {file_path}")
-        # The pointer is automatically garbage collected by CFFI.
         self._ptr = ffi.gc(ptr, lib.ztensor_reader_free)

     def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        # CFFI's garbage collector handles freeing the reader pointer automatically.
-        # No explicit free is needed here, simplifying the context manager.
         self._ptr = None

     def get_metadata(self, name: str) -> TensorMetadata:
@@ -139,32 +165,59 @@ class Reader:
         _check_ptr(meta_ptr, f"get_metadata: {name}")
         return TensorMetadata(meta_ptr)

-    def read_tensor(self, name: str
-        """
-
+    def read_tensor(self, name: str, to: str = 'numpy'):
+        """
+        Reads a tensor by name and returns it as a NumPy array or PyTorch tensor.
+        This is a zero-copy operation for both formats (for CPU tensors).

-
-
-
-            raise ZTensorError(f"Unsupported ztensor dtype: '{metadata.dtype_str}'")
+        Args:
+            name (str): The name of the tensor to read.
+            to (str): The desired output format. Either 'numpy' (default) or 'torch'.

+        Returns:
+            np.ndarray or torch.Tensor: The tensor data.
+        """
+        if to not in ['numpy', 'torch']:
+            raise ValueError(f"Unsupported format: '{to}'. Choose 'numpy' or 'torch'.")
+
+        metadata = self.get_metadata(name)
         view_ptr = lib.ztensor_reader_read_tensor_view(self._ptr, metadata._ptr)
         _check_ptr(view_ptr, f"read_tensor: {name}")

         # Let CFFI manage the lifetime of the view pointer.
         view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)

-
-
-
-
-
-
-
-
-
+        if to == 'numpy':
+            # Use the custom _ZTensorView to safely manage the FFI pointer lifetime.
+            return _ZTensorView(
+                buffer=ffi.buffer(view_ptr.data, view_ptr.len),
+                dtype=metadata.dtype, # This property raises on unsupported dtypes
+                shape=metadata.shape,
+                view_ptr=view_ptr
+            )
+
+        elif to == 'torch':
+            if not TORCH_AVAILABLE:
+                raise ZTensorError("PyTorch is not installed. Cannot return a torch tensor.")
+
+            # Get the corresponding torch dtype, raising if not supported.
+            torch_dtype = DTYPE_ZT_TO_TORCH.get(metadata.dtype_str)
+            if torch_dtype is None:
+                raise ZTensorError(
+                    f"Cannot read tensor '{name}' as a PyTorch tensor. "
+                    f"The dtype '{metadata.dtype_str}' is not supported by PyTorch."
+                )
+
+            # Create a tensor directly from the buffer to avoid numpy conversion issues.
+            buffer = ffi.buffer(view_ptr.data, view_ptr.len)
+            torch_tensor = torch.frombuffer(buffer, dtype=torch_dtype).reshape(metadata.shape)

-
+            # CRITICAL: Attach the memory owner to the tensor to manage its lifetime.
+            # This ensures the Rust memory (held by view_ptr) is not freed while the
+            # torch tensor is still in use.
+            torch_tensor._owner = view_ptr
+
+            return torch_tensor


 class Writer:
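Editor's note: the hunk above adds a to= switch to Reader.read_tensor. A minimal usage sketch based on the API shown in this diff (the file name and tensor name are hypothetical):

import ztensor

with ztensor.Reader("model.zt") as reader:               # hypothetical file
    arr = reader.read_tensor("embedding")                # NumPy view (default)
    t = reader.read_tensor("embedding", to="torch")      # torch.Tensor, requires PyTorch
    meta = reader.get_metadata("embedding")
    print(meta.shape, meta.dtype_str)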
@@ -174,8 +227,6 @@ class Writer:
         path_bytes = file_path.encode('utf-8')
         ptr = lib.ztensor_writer_create(path_bytes)
         _check_ptr(ptr, f"Writer create: {file_path}")
-        # The pointer is consumed by finalize, so we don't use ffi.gc here.
-        # The writer should be freed via finalize or ztensor_writer_free if finalize fails.
         self._ptr = ptr
         self._finalized = False

@@ -183,46 +234,67 @@ class Writer:
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        # Automatically finalize on exit if not already done and no error occurred.
         if self._ptr and not self._finalized:
             if exc_type is None:
                 self.finalize()
             else:
-                # If an error occurred, don't finalize, just free the writer to prevent leaks.
                 lib.ztensor_writer_free(self._ptr)
                 self._ptr = None

-    def add_tensor(self, name: str, tensor
-    """
-
+    def add_tensor(self, name: str, tensor):
+        """
+        Adds a NumPy or PyTorch tensor to the file (zero-copy).
+        Supports float16 and bfloat16 types.

-
-
+        Args:
+            name (str): The name of the tensor to add.
+            tensor (np.ndarray or torch.Tensor): The tensor data to write.
+        """
+        if not self._ptr: raise ZTensorError("Writer is closed or finalized.")

-
-
+        # --- Polymorphic tensor handling ---
+        if isinstance(tensor, np.ndarray):
+            tensor = np.ascontiguousarray(tensor)
+            shape = tensor.shape
+            dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
+            data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
+            nbytes = tensor.nbytes
+
+        elif TORCH_AVAILABLE and isinstance(tensor, torch.Tensor):
+            if tensor.is_cuda:
+                raise ZTensorError("Cannot write directly from a CUDA tensor. Copy to CPU first using .cpu().")
+            tensor = tensor.contiguous()
+            shape = tuple(tensor.shape)
+            dtype_str = DTYPE_TORCH_TO_ZT.get(tensor.dtype)
+            data_ptr = ffi.cast("unsigned char*", tensor.data_ptr())
+            nbytes = tensor.numel() * tensor.element_size()
+
+        else:
+            supported = "np.ndarray" + (" or torch.Tensor" if TORCH_AVAILABLE else "")
+            raise TypeError(f"Unsupported tensor type: {type(tensor)}. Must be {supported}.")

-        # The updated DTYPE_NP_TO_ZT will now correctly handle bfloat16 tensors.
-        dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
         if not dtype_str:
-
-
+            msg = f"Unsupported dtype: {tensor.dtype}."
+            if 'bfloat16' in str(tensor.dtype) and not ML_DTYPES_AVAILABLE:
+                msg += " For NumPy bfloat16 support, please install the 'ml_dtypes' package."
+            raise ZTensorError(msg)

-
-
+        name_bytes = name.encode('utf-8')
+        shape_array = np.array(shape, dtype=np.uint64)
+        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
+        dtype_bytes = dtype_str.encode('utf-8')

         status = lib.ztensor_writer_add_tensor(
-            self._ptr, name_bytes, shape_ptr, len(
-            dtype_bytes, data_ptr,
+            self._ptr, name_bytes, shape_ptr, len(shape),
+            dtype_bytes, data_ptr, nbytes
         )
         _check_status(status, f"add_tensor: {name}")

     def finalize(self):
         """Finalizes the zTensor file, writing the metadata index."""
         if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
-
         status = lib.ztensor_writer_finalize(self._ptr)
-        self._ptr = None
+        self._ptr = None
         self._finalized = True
         _check_status(status, "finalize")

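Editor's note: the writer-side change mirrors the reader. add_tensor now accepts either a NumPy array or a CPU torch.Tensor, and the context manager finalizes the file on a clean exit. A small usage sketch based on this diff (the output path and tensor names are made up):

import numpy as np
import ztensor

with ztensor.Writer("weights.zt") as writer:             # hypothetical output path
    writer.add_tensor("bias", np.zeros((16,), dtype=np.float16))   # float16 is new in 0.1.2
    # A CPU torch tensor works the same way when PyTorch is installed, e.g.:
    # writer.add_tensor("weight", torch.ones(4, 4, dtype=torch.bfloat16))
    # writer.finalize() is called automatically on clean exit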
ztensor/ztensor/ztensor.dll
CHANGED
Binary file
ztensor-0.1.2.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+ztensor-0.1.2.dist-info/METADATA,sha256=4UODGfQ_vW4gGHCIunIBOFRA4aB_BrbymnQQNa87vi4,4530
+ztensor-0.1.2.dist-info/WHEEL,sha256=S7OxZtuPihI-XN3jZuq2sqVMik-O-jyGLiThWItpyfk,93
+ztensor-0.1.2.dist-info/licenses/LICENSE,sha256=qxF7VFxBvMlfiDRJ5oXQuQYaloq0Tcbk95Pn0DFlnss,1084
+ztensor/__init__.py,sha256=MHfJeuyHa9r1dJ-8ezqRX0wpR5EBwhjkMzg4HlnLFL4,11600
+ztensor/ztensor/__init__.py,sha256=DDVvoEhcXithkluOJ4Dd7H6wIqKcxT6mm6vvPgrQMz4,138
+ztensor/ztensor/ffi.py,sha256=5HqR7Szwsn6HmKARaYQqF0nJKEUWZsWpYD4ZiOaoWnk,2756
+ztensor/ztensor/ztensor.dll,sha256=M4b64ywvca1ZaMHGSnNMKAwS8Hkwse88HzRrkmZo5xg,1011712
+ztensor-0.1.2.dist-info/RECORD,,
ztensor-0.1.1.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-ztensor-0.1.1.dist-info/METADATA,sha256=hpFnvbYe6JDNWzbgozx-V_DwwZF2e6dxL5QDeEa9uKs,4530
-ztensor-0.1.1.dist-info/WHEEL,sha256=S7OxZtuPihI-XN3jZuq2sqVMik-O-jyGLiThWItpyfk,93
-ztensor-0.1.1.dist-info/licenses/LICENSE,sha256=qxF7VFxBvMlfiDRJ5oXQuQYaloq0Tcbk95Pn0DFlnss,1084
-ztensor/__init__.py,sha256=3xHeBl9o9RyUriUlgPr3UwBZC3h15sgIGElnlJ2kIRY,9316
-ztensor/ztensor/__init__.py,sha256=DDVvoEhcXithkluOJ4Dd7H6wIqKcxT6mm6vvPgrQMz4,138
-ztensor/ztensor/ffi.py,sha256=5HqR7Szwsn6HmKARaYQqF0nJKEUWZsWpYD4ZiOaoWnk,2756
-ztensor/ztensor/ztensor.dll,sha256=Ytu-ECNwrdJvohdegnJcotJfXnZTufYKNczXns7i22Y,1011712
-ztensor-0.1.1.dist-info/RECORD,,
{ztensor-0.1.1.dist-info → ztensor-0.1.2.dist-info}/WHEEL
File without changes

{ztensor-0.1.1.dist-info → ztensor-0.1.2.dist-info}/licenses/LICENSE
File without changes