ztensor-0.1.1-py3-none-musllinux_1_2_i686.whl → ztensor-0.1.2-py3-none-musllinux_1_2_i686.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


ztensor/__init__.py CHANGED
@@ -1,6 +1,22 @@
 import numpy as np
 from .ztensor import ffi, lib
-from ml_dtypes import bfloat16
+
+# --- Optional PyTorch Import ---
+try:
+    import torch
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+
+# --- Optional ml_dtypes for bfloat16 in NumPy ---
+try:
+    from ml_dtypes import bfloat16 as np_bfloat16
+
+    ML_DTYPES_AVAILABLE = True
+except ImportError:
+    np_bfloat16 = None
+    ML_DTYPES_AVAILABLE = False
 
 
 # --- Pythonic Wrapper ---
@@ -12,7 +28,6 @@ class ZTensorError(Exception):
 # A custom ndarray subclass to safely manage the lifetime of the CFFI pointer.
 class _ZTensorView(np.ndarray):
     def __new__(cls, buffer, dtype, shape, view_ptr):
-        # Create an array from the buffer, reshape it, and cast it to our custom type.
         obj = np.frombuffer(buffer, dtype=dtype).reshape(shape).view(cls)
         # Attach the object that owns the memory to an attribute.
         obj._owner = view_ptr
@@ -45,26 +60,36 @@ def _check_status(status, func_name=""):
         raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")
 
 
-# Type Mappings between NumPy and ztensor
+# --- Type Mappings ---
+# NumPy Mappings
 DTYPE_NP_TO_ZT = {
-    np.dtype('float64'): 'float64', np.dtype('float32'): 'float32',
+    np.dtype('float64'): 'float64', np.dtype('float32'): 'float32', np.dtype('float16'): 'float16',
     np.dtype('int64'): 'int64', np.dtype('int32'): 'int32',
     np.dtype('int16'): 'int16', np.dtype('int8'): 'int8',
     np.dtype('uint64'): 'uint64', np.dtype('uint32'): 'uint32',
     np.dtype('uint16'): 'uint16', np.dtype('uint8'): 'uint8',
     np.dtype('bool'): 'bool',
-    # ADDED: Mapping for bfloat16 to handle writing
-    np.dtype(bfloat16): 'bfloat16',
 }
-# MODIFIED: Re-construct the reverse mapping
+if ML_DTYPES_AVAILABLE:
+    DTYPE_NP_TO_ZT[np.dtype(np_bfloat16)] = 'bfloat16'
 DTYPE_ZT_TO_NP = {v: k for k, v in DTYPE_NP_TO_ZT.items()}
 
+# PyTorch Mappings (if available)
+if TORCH_AVAILABLE:
+    DTYPE_TORCH_TO_ZT = {
+        torch.float64: 'float64', torch.float32: 'float32', torch.float16: 'float16',
+        torch.bfloat16: 'bfloat16',
+        torch.int64: 'int64', torch.int32: 'int32',
+        torch.int16: 'int16', torch.int8: 'int8',
+        torch.uint8: 'uint8', torch.bool: 'bool',
+    }
+    DTYPE_ZT_TO_TORCH = {v: k for k, v in DTYPE_TORCH_TO_ZT.items()}
+
 
 
 class TensorMetadata:
     """A Pythonic wrapper around the CTensorMetadata pointer."""
     def __init__(self, meta_ptr):
-        # The pointer is now automatically garbage collected by CFFI when this object dies.
        self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
        _check_ptr(self._ptr, "TensorMetadata constructor")
        self._name = None
@@ -76,7 +101,6 @@ class TensorMetadata:
         if self._name is None:
             name_ptr = lib.ztensor_metadata_get_name(self._ptr)
             _check_ptr(name_ptr, "get_name")
-            # ffi.string creates a copy, so we must free the Rust-allocated original.
             self._name = ffi.string(name_ptr).decode('utf-8')
             lib.ztensor_free_string(name_ptr)
         return self._name
@@ -86,7 +110,6 @@ class TensorMetadata:
         if self._dtype_str is None:
             dtype_ptr = lib.ztensor_metadata_get_dtype_str(self._ptr)
             _check_ptr(dtype_ptr, "get_dtype_str")
-            # ffi.string creates a copy, so we must free the Rust-allocated original.
             self._dtype_str = ffi.string(dtype_ptr).decode('utf-8')
             lib.ztensor_free_string(dtype_ptr)
         return self._dtype_str
@@ -94,10 +117,17 @@ class TensorMetadata:
     @property
     def dtype(self):
         """Returns the numpy dtype for this tensor."""
-        # This will now correctly return a bfloat16 dtype object when dtype_str is 'bfloat16'
-        return DTYPE_ZT_TO_NP.get(self.dtype_str)
+        dtype_str = self.dtype_str
+        dt = DTYPE_ZT_TO_NP.get(dtype_str)
+        if dt is None:
+            if dtype_str == 'bfloat16':
+                raise ZTensorError(
+                    "Cannot read 'bfloat16' tensor as NumPy array because the 'ml_dtypes' "
+                    "package is not installed. Please install it to proceed."
+                )
+            raise ZTensorError(f"Unsupported or unknown dtype string '{dtype_str}' found in tensor metadata.")
+        return dt
 
-    # RE-ENABLED: This property now works because the underlying FFI functions are available.
     @property
     def shape(self):
         if self._shape is None:
@@ -106,7 +136,6 @@ class TensorMetadata:
                 shape_data_ptr = lib.ztensor_metadata_get_shape_data(self._ptr)
                 _check_ptr(shape_data_ptr, "get_shape_data")
                 self._shape = tuple(shape_data_ptr[i] for i in range(shape_len))
-                # Free the array that was allocated on the Rust side.
                 lib.ztensor_free_u64_array(shape_data_ptr, shape_len)
             else:
                 self._shape = tuple()
@@ -120,15 +149,12 @@ class Reader:
         path_bytes = file_path.encode('utf-8')
         ptr = lib.ztensor_reader_open(path_bytes)
         _check_ptr(ptr, f"Reader open: {file_path}")
-        # The pointer is automatically garbage collected by CFFI.
         self._ptr = ffi.gc(ptr, lib.ztensor_reader_free)
 
     def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        # CFFI's garbage collector handles freeing the reader pointer automatically.
-        # No explicit free is needed here, simplifying the context manager.
         self._ptr = None
 
     def get_metadata(self, name: str) -> TensorMetadata:
@@ -139,32 +165,59 @@ class Reader:
         _check_ptr(meta_ptr, f"get_metadata: {name}")
         return TensorMetadata(meta_ptr)
 
-    def read_tensor(self, name: str) -> np.ndarray:
-        """Reads a tensor by name and returns it as a NumPy array (zero-copy)."""
-        metadata = self.get_metadata(name)
+    def read_tensor(self, name: str, to: str = 'numpy'):
+        """
+        Reads a tensor by name and returns it as a NumPy array or PyTorch tensor.
+        This is a zero-copy operation for both formats (for CPU tensors).
 
-        # This check is now more robust.
-        dtype = metadata.dtype
-        if dtype is None:
-            raise ZTensorError(f"Unsupported ztensor dtype: '{metadata.dtype_str}'")
+        Args:
+            name (str): The name of the tensor to read.
+            to (str): The desired output format. Either 'numpy' (default) or 'torch'.
 
+        Returns:
+            np.ndarray or torch.Tensor: The tensor data.
+        """
+        if to not in ['numpy', 'torch']:
+            raise ValueError(f"Unsupported format: '{to}'. Choose 'numpy' or 'torch'.")
+
+        metadata = self.get_metadata(name)
         view_ptr = lib.ztensor_reader_read_tensor_view(self._ptr, metadata._ptr)
         _check_ptr(view_ptr, f"read_tensor: {name}")
 
         # Let CFFI manage the lifetime of the view pointer.
         view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)
 
-        # CORRECTED: Create array using the subclass, which handles reshaping and memory.
-        # With the correct bfloat16 dtype, np.frombuffer will now correctly interpret
-        # the buffer, creating an array with the right number of elements (2048).
-        array = _ZTensorView(
-            buffer=ffi.buffer(view_ptr.data, view_ptr.len),
-            dtype=dtype,
-            shape=metadata.shape,
-            view_ptr=view_ptr
-        )
+        if to == 'numpy':
+            # Use the custom _ZTensorView to safely manage the FFI pointer lifetime.
+            return _ZTensorView(
+                buffer=ffi.buffer(view_ptr.data, view_ptr.len),
+                dtype=metadata.dtype, # This property raises on unsupported dtypes
+                shape=metadata.shape,
+                view_ptr=view_ptr
+            )
+
+        elif to == 'torch':
+            if not TORCH_AVAILABLE:
+                raise ZTensorError("PyTorch is not installed. Cannot return a torch tensor.")
+
+            # Get the corresponding torch dtype, raising if not supported.
+            torch_dtype = DTYPE_ZT_TO_TORCH.get(metadata.dtype_str)
+            if torch_dtype is None:
+                raise ZTensorError(
+                    f"Cannot read tensor '{name}' as a PyTorch tensor. "
+                    f"The dtype '{metadata.dtype_str}' is not supported by PyTorch."
+                )
+
+            # Create a tensor directly from the buffer to avoid numpy conversion issues.
+            buffer = ffi.buffer(view_ptr.data, view_ptr.len)
+            torch_tensor = torch.frombuffer(buffer, dtype=torch_dtype).reshape(metadata.shape)
 
-        return array
+            # CRITICAL: Attach the memory owner to the tensor to manage its lifetime.
+            # This ensures the Rust memory (held by view_ptr) is not freed while the
+            # torch tensor is still in use.
+            torch_tensor._owner = view_ptr
+
+            return torch_tensor
 
 
 class Writer:
@@ -174,8 +227,6 @@ class Writer:
         path_bytes = file_path.encode('utf-8')
         ptr = lib.ztensor_writer_create(path_bytes)
         _check_ptr(ptr, f"Writer create: {file_path}")
-        # The pointer is consumed by finalize, so we don't use ffi.gc here.
-        # The writer should be freed via finalize or ztensor_writer_free if finalize fails.
         self._ptr = ptr
         self._finalized = False
 
@@ -183,46 +234,67 @@ class Writer:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        # Automatically finalize on exit if not already done and no error occurred.
         if self._ptr and not self._finalized:
             if exc_type is None:
                 self.finalize()
             else:
-                # If an error occurred, don't finalize, just free the writer to prevent leaks.
                 lib.ztensor_writer_free(self._ptr)
         self._ptr = None
 
-    def add_tensor(self, name: str, tensor: np.ndarray):
-        """Adds a NumPy array as a tensor to the file."""
-        if not self._ptr: raise ZTensorError("Writer is closed or finalized.")
+    def add_tensor(self, name: str, tensor):
+        """
+        Adds a NumPy or PyTorch tensor to the file (zero-copy).
+        Supports float16 and bfloat16 types.
 
-        name_bytes = name.encode('utf-8')
-        tensor = np.ascontiguousarray(tensor) # Ensure data is contiguous.
+        Args:
+            name (str): The name of the tensor to add.
+            tensor (np.ndarray or torch.Tensor): The tensor data to write.
+        """
+        if not self._ptr: raise ZTensorError("Writer is closed or finalized.")
 
-        shape_array = np.array(tensor.shape, dtype=np.uint64)
-        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
+        # --- Polymorphic tensor handling ---
+        if isinstance(tensor, np.ndarray):
+            tensor = np.ascontiguousarray(tensor)
+            shape = tensor.shape
+            dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
+            data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
+            nbytes = tensor.nbytes
+
+        elif TORCH_AVAILABLE and isinstance(tensor, torch.Tensor):
+            if tensor.is_cuda:
+                raise ZTensorError("Cannot write directly from a CUDA tensor. Copy to CPU first using .cpu().")
+            tensor = tensor.contiguous()
+            shape = tuple(tensor.shape)
+            dtype_str = DTYPE_TORCH_TO_ZT.get(tensor.dtype)
+            data_ptr = ffi.cast("unsigned char*", tensor.data_ptr())
+            nbytes = tensor.numel() * tensor.element_size()
+
+        else:
+            supported = "np.ndarray" + (" or torch.Tensor" if TORCH_AVAILABLE else "")
+            raise TypeError(f"Unsupported tensor type: {type(tensor)}. Must be {supported}.")
 
-        # The updated DTYPE_NP_TO_ZT will now correctly handle bfloat16 tensors.
-        dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
         if not dtype_str:
-            raise ZTensorError(f"Unsupported NumPy dtype: {tensor.dtype}")
-        dtype_bytes = dtype_str.encode('utf-8')
+            msg = f"Unsupported dtype: {tensor.dtype}."
+            if 'bfloat16' in str(tensor.dtype) and not ML_DTYPES_AVAILABLE:
+                msg += " For NumPy bfloat16 support, please install the 'ml_dtypes' package."
+            raise ZTensorError(msg)
 
-        # CORRECTED: Cast to `unsigned char*` to match the CFFI definition and Rust FFI.
-        data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
+        name_bytes = name.encode('utf-8')
+        shape_array = np.array(shape, dtype=np.uint64)
+        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
+        dtype_bytes = dtype_str.encode('utf-8')
 
         status = lib.ztensor_writer_add_tensor(
-            self._ptr, name_bytes, shape_ptr, len(tensor.shape),
-            dtype_bytes, data_ptr, tensor.nbytes
+            self._ptr, name_bytes, shape_ptr, len(shape),
+            dtype_bytes, data_ptr, nbytes
        )
        _check_status(status, f"add_tensor: {name}")
 
     def finalize(self):
         """Finalizes the zTensor file, writing the metadata index."""
         if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
-
         status = lib.ztensor_writer_finalize(self._ptr)
-        self._ptr = None # The writer pointer is consumed and invalidated by the Rust call.
+        self._ptr = None
         self._finalized = True
         _check_status(status, "finalize")
 

ztensor-0.1.1.dist-info/METADATA → ztensor-0.1.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ztensor
-Version: 0.1.1
+Version: 0.1.2
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License

ztensor-0.1.1.dist-info/RECORD → ztensor-0.1.2.dist-info/RECORD
@@ -1,9 +1,9 @@
-ztensor-0.1.1.dist-info/METADATA,sha256=dpaZnxqi7u30Z3uH641XoVg_DeH5k25UamSBjzseBos,4446
-ztensor-0.1.1.dist-info/WHEEL,sha256=0pRFF2QEyOo7g4FVclldsFiaPr__8E5qQ8hSd30TcgM,102
-ztensor-0.1.1.dist-info/licenses/LICENSE,sha256=AoeyV1LzTyOz9sbr6uOzk_P0lW963DvhJHnVNVQlI3Y,1063
+ztensor-0.1.2.dist-info/METADATA,sha256=zJKbwLwUsstAXn5NM2OxH7i4r1En_3gQyNjkSC69lK8,4446
+ztensor-0.1.2.dist-info/WHEEL,sha256=0pRFF2QEyOo7g4FVclldsFiaPr__8E5qQ8hSd30TcgM,102
+ztensor-0.1.2.dist-info/licenses/LICENSE,sha256=AoeyV1LzTyOz9sbr6uOzk_P0lW963DvhJHnVNVQlI3Y,1063
 ztensor.libs/libgcc_s-b5472b99.so.1,sha256=wh8CpjXz9IccAyeERcB7YDEx7NH2jF-PykwOyYNeRRI,453841
-ztensor/__init__.py,sha256=pzpS0XNcpBb762OPUG7-3raEJxz7P99KX6CGAdh-DS0,9086
+ztensor/__init__.py,sha256=60wRwWWC_urfj2baZVHy_rGoYEFTPwW3pcSPmrvgJKs,11298
 ztensor/ztensor/__init__.py,sha256=sIpB0pJYFX20TdZapIoPMxqMz37wKJyxCkAlTemWDq4,140
 ztensor/ztensor/ffi.py,sha256=J7CG26lx0Xu0IjYR7GWev-BDgGQXseKA4cjkh0IhnLE,2746
 ztensor/ztensor/libztensor.so,sha256=p5LH8-ZnPfUbhV8R7gYmRCPrxHIb5rm3OUPNW8PDenM,1750593
-ztensor-0.1.1.dist-info/RECORD,,
+ztensor-0.1.2.dist-info/RECORD,,
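
Usage sketch (not part of the published diff): the example below is a minimal illustration of the API added in 0.1.2, based only on the signatures visible in ztensor/__init__.py above — Writer.add_tensor now accepts either a NumPy array or a CPU PyTorch tensor, and Reader.read_tensor gained a to= argument ('numpy' by default, or 'torch'). torch and ml_dtypes are optional dependencies, and the file name and tensor names here are invented for illustration.

# Illustrative only: "example.zt" and the tensor names are made up.
# Requires torch for the PyTorch paths (and ml_dtypes for NumPy bfloat16 reads).
import numpy as np
import torch

import ztensor

# Write tensors; finalize() runs automatically when the context manager exits cleanly.
with ztensor.Writer("example.zt") as writer:
    writer.add_tensor("embeddings", np.random.rand(1024, 64).astype(np.float32))
    writer.add_tensor("weights_bf16", torch.randn(512, 512, dtype=torch.bfloat16))

# Read tensors back, either as NumPy arrays (default) or as PyTorch tensors.
with ztensor.Reader("example.zt") as reader:
    emb = reader.read_tensor("embeddings")                    # numpy.ndarray, float32
    weights = reader.read_tensor("weights_bf16", to="torch")  # torch.Tensor, bfloat16
    print(emb.shape, weights.dtype)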