ztensor 0.1.0__py3-none-win32.whl → 0.1.2__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ztensor might be problematic. See the registry's page for this release for more details.
- ztensor/__init__.py +128 -42
- ztensor/ztensor/ztensor.dll +0 -0
- {ztensor-0.1.0.dist-info → ztensor-0.1.2.dist-info}/METADATA +2 -1
- ztensor-0.1.2.dist-info/RECORD +8 -0
- ztensor-0.1.0.dist-info/RECORD +0 -8
- {ztensor-0.1.0.dist-info → ztensor-0.1.2.dist-info}/WHEEL +0 -0
- {ztensor-0.1.0.dist-info → ztensor-0.1.2.dist-info}/licenses/LICENSE +0 -0
ztensor/__init__.py
CHANGED
|
@@ -1,6 +1,23 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
from .ztensor import ffi, lib
|
|
3
3
|
|
|
4
|
+
# --- Optional PyTorch Import ---
|
|
5
|
+
try:
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
TORCH_AVAILABLE = True
|
|
9
|
+
except ImportError:
|
|
10
|
+
TORCH_AVAILABLE = False
|
|
11
|
+
|
|
12
|
+
# --- Optional ml_dtypes for bfloat16 in NumPy ---
|
|
13
|
+
try:
|
|
14
|
+
from ml_dtypes import bfloat16 as np_bfloat16
|
|
15
|
+
|
|
16
|
+
ML_DTYPES_AVAILABLE = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
np_bfloat16 = None
|
|
19
|
+
ML_DTYPES_AVAILABLE = False
|
|
20
|
+
|
|
4
21
|
|
|
5
22
|
# --- Pythonic Wrapper ---
|
|
6
23
|
class ZTensorError(Exception):
|
|
@@ -11,7 +28,6 @@ class ZTensorError(Exception):
|
|
|
11
28
|
# A custom ndarray subclass to safely manage the lifetime of the CFFI pointer.
|
|
12
29
|
class _ZTensorView(np.ndarray):
|
|
13
30
|
def __new__(cls, buffer, dtype, shape, view_ptr):
|
|
14
|
-
# Create an array from the buffer, reshape it, and cast it to our custom type.
|
|
15
31
|
obj = np.frombuffer(buffer, dtype=dtype).reshape(shape).view(cls)
|
|
16
32
|
# Attach the object that owns the memory to an attribute.
|
|
17
33
|
obj._owner = view_ptr
|
|
@@ -44,23 +60,36 @@ def _check_status(status, func_name=""):
|
|
|
44
60
|
raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")
|
|
45
61
|
|
|
46
62
|
|
|
47
|
-
# Type Mappings
|
|
63
|
+
# --- Type Mappings ---
|
|
64
|
+
# NumPy Mappings
|
|
48
65
|
DTYPE_NP_TO_ZT = {
|
|
49
|
-
np.dtype('float64'): 'float64', np.dtype('float32'): 'float32',
|
|
66
|
+
np.dtype('float64'): 'float64', np.dtype('float32'): 'float32', np.dtype('float16'): 'float16',
|
|
50
67
|
np.dtype('int64'): 'int64', np.dtype('int32'): 'int32',
|
|
51
68
|
np.dtype('int16'): 'int16', np.dtype('int8'): 'int8',
|
|
52
69
|
np.dtype('uint64'): 'uint64', np.dtype('uint32'): 'uint32',
|
|
53
70
|
np.dtype('uint16'): 'uint16', np.dtype('uint8'): 'uint8',
|
|
54
71
|
np.dtype('bool'): 'bool',
|
|
55
72
|
}
|
|
73
|
+
if ML_DTYPES_AVAILABLE:
|
|
74
|
+
DTYPE_NP_TO_ZT[np.dtype(np_bfloat16)] = 'bfloat16'
|
|
56
75
|
DTYPE_ZT_TO_NP = {v: k for k, v in DTYPE_NP_TO_ZT.items()}
|
|
57
76
|
|
|
77
|
+
# PyTorch Mappings (if available)
|
|
78
|
+
if TORCH_AVAILABLE:
|
|
79
|
+
DTYPE_TORCH_TO_ZT = {
|
|
80
|
+
torch.float64: 'float64', torch.float32: 'float32', torch.float16: 'float16',
|
|
81
|
+
torch.bfloat16: 'bfloat16',
|
|
82
|
+
torch.int64: 'int64', torch.int32: 'int32',
|
|
83
|
+
torch.int16: 'int16', torch.int8: 'int8',
|
|
84
|
+
torch.uint8: 'uint8', torch.bool: 'bool',
|
|
85
|
+
}
|
|
86
|
+
DTYPE_ZT_TO_TORCH = {v: k for k, v in DTYPE_TORCH_TO_ZT.items()}
|
|
87
|
+
|
|
58
88
|
|
|
59
89
|
class TensorMetadata:
|
|
60
90
|
"""A Pythonic wrapper around the CTensorMetadata pointer."""
|
|
61
91
|
|
|
62
92
|
def __init__(self, meta_ptr):
|
|
63
|
-
# The pointer is now automatically garbage collected by CFFI when this object dies.
|
|
64
93
|
self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
|
|
65
94
|
_check_ptr(self._ptr, "TensorMetadata constructor")
|
|
66
95
|
self._name = None
|
|
@@ -72,7 +101,6 @@ class TensorMetadata:
|
|
|
72
101
|
if self._name is None:
|
|
73
102
|
name_ptr = lib.ztensor_metadata_get_name(self._ptr)
|
|
74
103
|
_check_ptr(name_ptr, "get_name")
|
|
75
|
-
# ffi.string creates a copy, so we must free the Rust-allocated original.
|
|
76
104
|
self._name = ffi.string(name_ptr).decode('utf-8')
|
|
77
105
|
lib.ztensor_free_string(name_ptr)
|
|
78
106
|
return self._name
|
|
@@ -82,7 +110,6 @@ class TensorMetadata:
|
|
|
82
110
|
if self._dtype_str is None:
|
|
83
111
|
dtype_ptr = lib.ztensor_metadata_get_dtype_str(self._ptr)
|
|
84
112
|
_check_ptr(dtype_ptr, "get_dtype_str")
|
|
85
|
-
# ffi.string creates a copy, so we must free the Rust-allocated original.
|
|
86
113
|
self._dtype_str = ffi.string(dtype_ptr).decode('utf-8')
|
|
87
114
|
lib.ztensor_free_string(dtype_ptr)
|
|
88
115
|
return self._dtype_str
|
|
@@ -90,9 +117,17 @@ class TensorMetadata:
|
|
|
90
117
|
@property
|
|
91
118
|
def dtype(self):
|
|
92
119
|
"""Returns the numpy dtype for this tensor."""
|
|
93
|
-
|
|
120
|
+
dtype_str = self.dtype_str
|
|
121
|
+
dt = DTYPE_ZT_TO_NP.get(dtype_str)
|
|
122
|
+
if dt is None:
|
|
123
|
+
if dtype_str == 'bfloat16':
|
|
124
|
+
raise ZTensorError(
|
|
125
|
+
"Cannot read 'bfloat16' tensor as NumPy array because the 'ml_dtypes' "
|
|
126
|
+
"package is not installed. Please install it to proceed."
|
|
127
|
+
)
|
|
128
|
+
raise ZTensorError(f"Unsupported or unknown dtype string '{dtype_str}' found in tensor metadata.")
|
|
129
|
+
return dt
|
|
94
130
|
|
|
95
|
-
# RE-ENABLED: This property now works because the underlying FFI functions are available.
|
|
96
131
|
@property
|
|
97
132
|
def shape(self):
|
|
98
133
|
if self._shape is None:
|
|
@@ -101,7 +136,6 @@ class TensorMetadata:
|
|
|
101
136
|
shape_data_ptr = lib.ztensor_metadata_get_shape_data(self._ptr)
|
|
102
137
|
_check_ptr(shape_data_ptr, "get_shape_data")
|
|
103
138
|
self._shape = tuple(shape_data_ptr[i] for i in range(shape_len))
|
|
104
|
-
# Free the array that was allocated on the Rust side.
|
|
105
139
|
lib.ztensor_free_u64_array(shape_data_ptr, shape_len)
|
|
106
140
|
else:
|
|
107
141
|
self._shape = tuple()
|
|
@@ -115,15 +149,12 @@ class Reader:
|
|
|
115
149
|
path_bytes = file_path.encode('utf-8')
|
|
116
150
|
ptr = lib.ztensor_reader_open(path_bytes)
|
|
117
151
|
_check_ptr(ptr, f"Reader open: {file_path}")
|
|
118
|
-
# The pointer is automatically garbage collected by CFFI.
|
|
119
152
|
self._ptr = ffi.gc(ptr, lib.ztensor_reader_free)
|
|
120
153
|
|
|
121
154
|
def __enter__(self):
|
|
122
155
|
return self
|
|
123
156
|
|
|
124
157
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
125
|
-
# CFFI's garbage collector handles freeing the reader pointer automatically.
|
|
126
|
-
# No explicit free is needed here, simplifying the context manager.
|
|
127
158
|
self._ptr = None
|
|
128
159
|
|
|
129
160
|
def get_metadata(self, name: str) -> TensorMetadata:
|
|
@@ -134,8 +165,21 @@ class Reader:
|
|
|
134
165
|
_check_ptr(meta_ptr, f"get_metadata: {name}")
|
|
135
166
|
return TensorMetadata(meta_ptr)
|
|
136
167
|
|
|
137
|
-
def read_tensor(self, name: str
|
|
138
|
-
"""
|
|
168
|
+
def read_tensor(self, name: str, to: str = 'numpy'):
|
|
169
|
+
"""
|
|
170
|
+
Reads a tensor by name and returns it as a NumPy array or PyTorch tensor.
|
|
171
|
+
This is a zero-copy operation for both formats (for CPU tensors).
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
name (str): The name of the tensor to read.
|
|
175
|
+
to (str): The desired output format. Either 'numpy' (default) or 'torch'.
|
|
176
|
+
|
|
177
|
+
Returns:
|
|
178
|
+
np.ndarray or torch.Tensor: The tensor data.
|
|
179
|
+
"""
|
|
180
|
+
if to not in ['numpy', 'torch']:
|
|
181
|
+
raise ValueError(f"Unsupported format: '{to}'. Choose 'numpy' or 'torch'.")
|
|
182
|
+
|
|
139
183
|
metadata = self.get_metadata(name)
|
|
140
184
|
view_ptr = lib.ztensor_reader_read_tensor_view(self._ptr, metadata._ptr)
|
|
141
185
|
_check_ptr(view_ptr, f"read_tensor: {name}")
|
|
@@ -143,15 +187,37 @@ class Reader:
|
|
|
143
187
|
# Let CFFI manage the lifetime of the view pointer.
|
|
144
188
|
view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)
|
|
145
189
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
190
|
+
if to == 'numpy':
|
|
191
|
+
# Use the custom _ZTensorView to safely manage the FFI pointer lifetime.
|
|
192
|
+
return _ZTensorView(
|
|
193
|
+
buffer=ffi.buffer(view_ptr.data, view_ptr.len),
|
|
194
|
+
dtype=metadata.dtype, # This property raises on unsupported dtypes
|
|
195
|
+
shape=metadata.shape,
|
|
196
|
+
view_ptr=view_ptr
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
elif to == 'torch':
|
|
200
|
+
if not TORCH_AVAILABLE:
|
|
201
|
+
raise ZTensorError("PyTorch is not installed. Cannot return a torch tensor.")
|
|
202
|
+
|
|
203
|
+
# Get the corresponding torch dtype, raising if not supported.
|
|
204
|
+
torch_dtype = DTYPE_ZT_TO_TORCH.get(metadata.dtype_str)
|
|
205
|
+
if torch_dtype is None:
|
|
206
|
+
raise ZTensorError(
|
|
207
|
+
f"Cannot read tensor '{name}' as a PyTorch tensor. "
|
|
208
|
+
f"The dtype '{metadata.dtype_str}' is not supported by PyTorch."
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
# Create a tensor directly from the buffer to avoid numpy conversion issues.
|
|
212
|
+
buffer = ffi.buffer(view_ptr.data, view_ptr.len)
|
|
213
|
+
torch_tensor = torch.frombuffer(buffer, dtype=torch_dtype).reshape(metadata.shape)
|
|
153
214
|
|
|
154
|
-
|
|
215
|
+
# CRITICAL: Attach the memory owner to the tensor to manage its lifetime.
|
|
216
|
+
# This ensures the Rust memory (held by view_ptr) is not freed while the
|
|
217
|
+
# torch tensor is still in use.
|
|
218
|
+
torch_tensor._owner = view_ptr
|
|
219
|
+
|
|
220
|
+
return torch_tensor
|
|
155
221
|
|
|
156
222
|
|
|
157
223
|
class Writer:
|
|
@@ -161,8 +227,6 @@ class Writer:
|
|
|
161
227
|
path_bytes = file_path.encode('utf-8')
|
|
162
228
|
ptr = lib.ztensor_writer_create(path_bytes)
|
|
163
229
|
_check_ptr(ptr, f"Writer create: {file_path}")
|
|
164
|
-
# The pointer is consumed by finalize, so we don't use ffi.gc here.
|
|
165
|
-
# The writer should be freed via finalize or ztensor_writer_free if finalize fails.
|
|
166
230
|
self._ptr = ptr
|
|
167
231
|
self._finalized = False
|
|
168
232
|
|
|
@@ -170,45 +234,67 @@ class Writer:
|
|
|
170
234
|
return self
|
|
171
235
|
|
|
172
236
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
173
|
-
# Automatically finalize on exit if not already done and no error occurred.
|
|
174
237
|
if self._ptr and not self._finalized:
|
|
175
238
|
if exc_type is None:
|
|
176
239
|
self.finalize()
|
|
177
240
|
else:
|
|
178
|
-
# If an error occurred, don't finalize, just free the writer to prevent leaks.
|
|
179
241
|
lib.ztensor_writer_free(self._ptr)
|
|
180
242
|
self._ptr = None
|
|
181
243
|
|
|
182
|
-
def add_tensor(self, name: str, tensor
|
|
183
|
-
"""
|
|
184
|
-
|
|
244
|
+
def add_tensor(self, name: str, tensor):
|
|
245
|
+
"""
|
|
246
|
+
Adds a NumPy or PyTorch tensor to the file (zero-copy).
|
|
247
|
+
Supports float16 and bfloat16 types.
|
|
185
248
|
|
|
186
|
-
|
|
187
|
-
|
|
249
|
+
Args:
|
|
250
|
+
name (str): The name of the tensor to add.
|
|
251
|
+
tensor (np.ndarray or torch.Tensor): The tensor data to write.
|
|
252
|
+
"""
|
|
253
|
+
if not self._ptr: raise ZTensorError("Writer is closed or finalized.")
|
|
188
254
|
|
|
189
|
-
|
|
190
|
-
|
|
255
|
+
# --- Polymorphic tensor handling ---
|
|
256
|
+
if isinstance(tensor, np.ndarray):
|
|
257
|
+
tensor = np.ascontiguousarray(tensor)
|
|
258
|
+
shape = tensor.shape
|
|
259
|
+
dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
|
|
260
|
+
data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
|
|
261
|
+
nbytes = tensor.nbytes
|
|
262
|
+
|
|
263
|
+
elif TORCH_AVAILABLE and isinstance(tensor, torch.Tensor):
|
|
264
|
+
if tensor.is_cuda:
|
|
265
|
+
raise ZTensorError("Cannot write directly from a CUDA tensor. Copy to CPU first using .cpu().")
|
|
266
|
+
tensor = tensor.contiguous()
|
|
267
|
+
shape = tuple(tensor.shape)
|
|
268
|
+
dtype_str = DTYPE_TORCH_TO_ZT.get(tensor.dtype)
|
|
269
|
+
data_ptr = ffi.cast("unsigned char*", tensor.data_ptr())
|
|
270
|
+
nbytes = tensor.numel() * tensor.element_size()
|
|
271
|
+
|
|
272
|
+
else:
|
|
273
|
+
supported = "np.ndarray" + (" or torch.Tensor" if TORCH_AVAILABLE else "")
|
|
274
|
+
raise TypeError(f"Unsupported tensor type: {type(tensor)}. Must be {supported}.")
|
|
191
275
|
|
|
192
|
-
dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
|
|
193
276
|
if not dtype_str:
|
|
194
|
-
|
|
195
|
-
|
|
277
|
+
msg = f"Unsupported dtype: {tensor.dtype}."
|
|
278
|
+
if 'bfloat16' in str(tensor.dtype) and not ML_DTYPES_AVAILABLE:
|
|
279
|
+
msg += " For NumPy bfloat16 support, please install the 'ml_dtypes' package."
|
|
280
|
+
raise ZTensorError(msg)
|
|
196
281
|
|
|
197
|
-
|
|
198
|
-
|
|
282
|
+
name_bytes = name.encode('utf-8')
|
|
283
|
+
shape_array = np.array(shape, dtype=np.uint64)
|
|
284
|
+
shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
|
|
285
|
+
dtype_bytes = dtype_str.encode('utf-8')
|
|
199
286
|
|
|
200
287
|
status = lib.ztensor_writer_add_tensor(
|
|
201
|
-
self._ptr, name_bytes, shape_ptr, len(
|
|
202
|
-
dtype_bytes, data_ptr,
|
|
288
|
+
self._ptr, name_bytes, shape_ptr, len(shape),
|
|
289
|
+
dtype_bytes, data_ptr, nbytes
|
|
203
290
|
)
|
|
204
291
|
_check_status(status, f"add_tensor: {name}")
|
|
205
292
|
|
|
206
293
|
def finalize(self):
|
|
207
294
|
"""Finalizes the zTensor file, writing the metadata index."""
|
|
208
295
|
if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
|
|
209
|
-
|
|
210
296
|
status = lib.ztensor_writer_finalize(self._ptr)
|
|
211
|
-
self._ptr = None
|
|
297
|
+
self._ptr = None
|
|
212
298
|
self._finalized = True
|
|
213
299
|
_check_status(status, "finalize")
|
|
214
300
|
|
ztensor/ztensor/ztensor.dll
CHANGED
|
Binary file
|
{ztensor-0.1.0.dist-info → ztensor-0.1.2.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ztensor
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Classifier: Programming Language :: Rust
|
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
|
6
6
|
Classifier: License :: OSI Approved :: MIT License
|
|
@@ -9,6 +9,7 @@ Classifier: Intended Audience :: Developers
|
|
|
9
9
|
Classifier: Topic :: Scientific/Engineering
|
|
10
10
|
Requires-Dist: numpy
|
|
11
11
|
Requires-Dist: cffi
|
|
12
|
+
Requires-Dist: ml-dtypes
|
|
12
13
|
License-File: LICENSE
|
|
13
14
|
Summary: Python bindings for the zTensor library.
|
|
14
15
|
Author: In Gim <in.gim@yale.edu>
|
ztensor-0.1.2.dist-info/RECORD
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
ztensor-0.1.2.dist-info/METADATA,sha256=4UODGfQ_vW4gGHCIunIBOFRA4aB_BrbymnQQNa87vi4,4530
|
|
2
|
+
ztensor-0.1.2.dist-info/WHEEL,sha256=v0atuDHRmBdo6voLOtYJ8NlQfECX5fZNHOXLlNqbhus,89
|
|
3
|
+
ztensor-0.1.2.dist-info/licenses/LICENSE,sha256=qxF7VFxBvMlfiDRJ5oXQuQYaloq0Tcbk95Pn0DFlnss,1084
|
|
4
|
+
ztensor/__init__.py,sha256=MHfJeuyHa9r1dJ-8ezqRX0wpR5EBwhjkMzg4HlnLFL4,11600
|
|
5
|
+
ztensor/ztensor/__init__.py,sha256=DDVvoEhcXithkluOJ4Dd7H6wIqKcxT6mm6vvPgrQMz4,138
|
|
6
|
+
ztensor/ztensor/ffi.py,sha256=5HqR7Szwsn6HmKARaYQqF0nJKEUWZsWpYD4ZiOaoWnk,2756
|
|
7
|
+
ztensor/ztensor/ztensor.dll,sha256=RfrdQFVVG9QGDXU9z4nGxviysq5BCfxrZXTrdQ7wJF4,871936
|
|
8
|
+
ztensor-0.1.2.dist-info/RECORD,,
|
ztensor-0.1.0.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
ztensor-0.1.0.dist-info/METADATA,sha256=4s1ZEgsJePgp3S96g56LTjgNeGZeopSGoFk6jvZkXLE,4505
|
|
2
|
-
ztensor-0.1.0.dist-info/WHEEL,sha256=v0atuDHRmBdo6voLOtYJ8NlQfECX5fZNHOXLlNqbhus,89
|
|
3
|
-
ztensor-0.1.0.dist-info/licenses/LICENSE,sha256=qxF7VFxBvMlfiDRJ5oXQuQYaloq0Tcbk95Pn0DFlnss,1084
|
|
4
|
-
ztensor/__init__.py,sha256=Blfw3ZTJQPAsFQgqAw2y2KY_Zwh5Do-hjkOMS_Fx5Xs,8619
|
|
5
|
-
ztensor/ztensor/__init__.py,sha256=DDVvoEhcXithkluOJ4Dd7H6wIqKcxT6mm6vvPgrQMz4,138
|
|
6
|
-
ztensor/ztensor/ffi.py,sha256=5HqR7Szwsn6HmKARaYQqF0nJKEUWZsWpYD4ZiOaoWnk,2756
|
|
7
|
-
ztensor/ztensor/ztensor.dll,sha256=a8TSYxK-0prNrBkjMsr02Dz8GB69qLdwAfq6DNRGP78,871936
|
|
8
|
-
ztensor-0.1.0.dist-info/RECORD,,
|
|
{ztensor-0.1.0.dist-info → ztensor-0.1.2.dist-info}/WHEEL — file without changes
|
|
{ztensor-0.1.0.dist-info → ztensor-0.1.2.dist-info}/licenses/LICENSE — file without changes
|