ztensor 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ztensor might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ztensor
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: MIT License
@@ -9,6 +9,7 @@ Classifier: Intended Audience :: Developers
9
9
  Classifier: Topic :: Scientific/Engineering
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: cffi
12
+ Requires-Dist: ml-dtypes
12
13
  License-File: LICENSE
13
14
  Summary: Python bindings for the zTensor library.
14
15
  Author: In Gim <in.gim@yale.edu>
@@ -10,7 +10,7 @@ build-backend = "maturin"
10
10
  [project]
11
11
  # The name of your package on PyPI.
12
12
  name = "ztensor"
13
- version = "0.1.0"
13
+ version = "0.1.2"
14
14
  description = "Python bindings for the zTensor library."
15
15
  readme = "README.md" # It's good practice to have a README.
16
16
  authors = [
@@ -27,7 +27,8 @@ classifiers = [
27
27
  ]
28
28
  dependencies = [
29
29
  "numpy",
30
- "cffi"
30
+ "cffi",
31
+ "ml_dtypes"
31
32
  ]
32
33
 
33
34
  [project.urls]
@@ -2,9 +2,17 @@ import os
2
2
 
3
3
  import numpy as np
4
4
  from ztensor import Writer, Reader, ZTensorError
5
+ import torch
5
6
 
7
+ model_path = "llama1b.zt"
6
8
 
7
- file_path = "test_tensors.zt"
9
+ with Reader(model_path) as reader:
10
+ tensor_data_np = reader.read_tensor("model.layers.1.self_attn.k_proj.weight", to="torch")
11
+ print(tensor_data_np.dtype)
12
+
13
+
14
+
15
+ file_path = "../test_tensors.zt"
8
16
 
9
17
  # --- Write Tensors ---
10
18
  print(f"--- Writing to {file_path} ---")
@@ -1,6 +1,23 @@
1
1
  import numpy as np
2
2
  from .ztensor import ffi, lib
3
3
 
4
+ # --- Optional PyTorch Import ---
5
+ try:
6
+ import torch
7
+
8
+ TORCH_AVAILABLE = True
9
+ except ImportError:
10
+ TORCH_AVAILABLE = False
11
+
12
+ # --- Optional ml_dtypes for bfloat16 in NumPy ---
13
+ try:
14
+ from ml_dtypes import bfloat16 as np_bfloat16
15
+
16
+ ML_DTYPES_AVAILABLE = True
17
+ except ImportError:
18
+ np_bfloat16 = None
19
+ ML_DTYPES_AVAILABLE = False
20
+
4
21
 
5
22
  # --- Pythonic Wrapper ---
6
23
  class ZTensorError(Exception):
@@ -11,7 +28,6 @@ class ZTensorError(Exception):
11
28
  # A custom ndarray subclass to safely manage the lifetime of the CFFI pointer.
12
29
  class _ZTensorView(np.ndarray):
13
30
  def __new__(cls, buffer, dtype, shape, view_ptr):
14
- # Create an array from the buffer, reshape it, and cast it to our custom type.
15
31
  obj = np.frombuffer(buffer, dtype=dtype).reshape(shape).view(cls)
16
32
  # Attach the object that owns the memory to an attribute.
17
33
  obj._owner = view_ptr
@@ -44,23 +60,36 @@ def _check_status(status, func_name=""):
44
60
  raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")
45
61
 
46
62
 
47
- # Type Mappings between NumPy and ztensor
63
+ # --- Type Mappings ---
64
+ # NumPy Mappings
48
65
  DTYPE_NP_TO_ZT = {
49
- np.dtype('float64'): 'float64', np.dtype('float32'): 'float32',
66
+ np.dtype('float64'): 'float64', np.dtype('float32'): 'float32', np.dtype('float16'): 'float16',
50
67
  np.dtype('int64'): 'int64', np.dtype('int32'): 'int32',
51
68
  np.dtype('int16'): 'int16', np.dtype('int8'): 'int8',
52
69
  np.dtype('uint64'): 'uint64', np.dtype('uint32'): 'uint32',
53
70
  np.dtype('uint16'): 'uint16', np.dtype('uint8'): 'uint8',
54
71
  np.dtype('bool'): 'bool',
55
72
  }
73
+ if ML_DTYPES_AVAILABLE:
74
+ DTYPE_NP_TO_ZT[np.dtype(np_bfloat16)] = 'bfloat16'
56
75
  DTYPE_ZT_TO_NP = {v: k for k, v in DTYPE_NP_TO_ZT.items()}
57
76
 
77
+ # PyTorch Mappings (if available)
78
+ if TORCH_AVAILABLE:
79
+ DTYPE_TORCH_TO_ZT = {
80
+ torch.float64: 'float64', torch.float32: 'float32', torch.float16: 'float16',
81
+ torch.bfloat16: 'bfloat16',
82
+ torch.int64: 'int64', torch.int32: 'int32',
83
+ torch.int16: 'int16', torch.int8: 'int8',
84
+ torch.uint8: 'uint8', torch.bool: 'bool',
85
+ }
86
+ DTYPE_ZT_TO_TORCH = {v: k for k, v in DTYPE_TORCH_TO_ZT.items()}
87
+
58
88
 
59
89
  class TensorMetadata:
60
90
  """A Pythonic wrapper around the CTensorMetadata pointer."""
61
91
 
62
92
  def __init__(self, meta_ptr):
63
- # The pointer is now automatically garbage collected by CFFI when this object dies.
64
93
  self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
65
94
  _check_ptr(self._ptr, "TensorMetadata constructor")
66
95
  self._name = None
@@ -72,7 +101,6 @@ class TensorMetadata:
72
101
  if self._name is None:
73
102
  name_ptr = lib.ztensor_metadata_get_name(self._ptr)
74
103
  _check_ptr(name_ptr, "get_name")
75
- # ffi.string creates a copy, so we must free the Rust-allocated original.
76
104
  self._name = ffi.string(name_ptr).decode('utf-8')
77
105
  lib.ztensor_free_string(name_ptr)
78
106
  return self._name
@@ -82,7 +110,6 @@ class TensorMetadata:
82
110
  if self._dtype_str is None:
83
111
  dtype_ptr = lib.ztensor_metadata_get_dtype_str(self._ptr)
84
112
  _check_ptr(dtype_ptr, "get_dtype_str")
85
- # ffi.string creates a copy, so we must free the Rust-allocated original.
86
113
  self._dtype_str = ffi.string(dtype_ptr).decode('utf-8')
87
114
  lib.ztensor_free_string(dtype_ptr)
88
115
  return self._dtype_str
@@ -90,9 +117,17 @@ class TensorMetadata:
90
117
  @property
91
118
  def dtype(self):
92
119
  """Returns the numpy dtype for this tensor."""
93
- return DTYPE_ZT_TO_NP.get(self.dtype_str)
120
+ dtype_str = self.dtype_str
121
+ dt = DTYPE_ZT_TO_NP.get(dtype_str)
122
+ if dt is None:
123
+ if dtype_str == 'bfloat16':
124
+ raise ZTensorError(
125
+ "Cannot read 'bfloat16' tensor as NumPy array because the 'ml_dtypes' "
126
+ "package is not installed. Please install it to proceed."
127
+ )
128
+ raise ZTensorError(f"Unsupported or unknown dtype string '{dtype_str}' found in tensor metadata.")
129
+ return dt
94
130
 
95
- # RE-ENABLED: This property now works because the underlying FFI functions are available.
96
131
  @property
97
132
  def shape(self):
98
133
  if self._shape is None:
@@ -101,7 +136,6 @@ class TensorMetadata:
101
136
  shape_data_ptr = lib.ztensor_metadata_get_shape_data(self._ptr)
102
137
  _check_ptr(shape_data_ptr, "get_shape_data")
103
138
  self._shape = tuple(shape_data_ptr[i] for i in range(shape_len))
104
- # Free the array that was allocated on the Rust side.
105
139
  lib.ztensor_free_u64_array(shape_data_ptr, shape_len)
106
140
  else:
107
141
  self._shape = tuple()
@@ -115,15 +149,12 @@ class Reader:
115
149
  path_bytes = file_path.encode('utf-8')
116
150
  ptr = lib.ztensor_reader_open(path_bytes)
117
151
  _check_ptr(ptr, f"Reader open: {file_path}")
118
- # The pointer is automatically garbage collected by CFFI.
119
152
  self._ptr = ffi.gc(ptr, lib.ztensor_reader_free)
120
153
 
121
154
  def __enter__(self):
122
155
  return self
123
156
 
124
157
  def __exit__(self, exc_type, exc_val, exc_tb):
125
- # CFFI's garbage collector handles freeing the reader pointer automatically.
126
- # No explicit free is needed here, simplifying the context manager.
127
158
  self._ptr = None
128
159
 
129
160
  def get_metadata(self, name: str) -> TensorMetadata:
@@ -134,8 +165,21 @@ class Reader:
134
165
  _check_ptr(meta_ptr, f"get_metadata: {name}")
135
166
  return TensorMetadata(meta_ptr)
136
167
 
137
- def read_tensor(self, name: str) -> np.ndarray:
138
- """Reads a tensor by name and returns it as a NumPy array (zero-copy)."""
168
+ def read_tensor(self, name: str, to: str = 'numpy'):
169
+ """
170
+ Reads a tensor by name and returns it as a NumPy array or PyTorch tensor.
171
+ This is a zero-copy operation for both formats (for CPU tensors).
172
+
173
+ Args:
174
+ name (str): The name of the tensor to read.
175
+ to (str): The desired output format. Either 'numpy' (default) or 'torch'.
176
+
177
+ Returns:
178
+ np.ndarray or torch.Tensor: The tensor data.
179
+ """
180
+ if to not in ['numpy', 'torch']:
181
+ raise ValueError(f"Unsupported format: '{to}'. Choose 'numpy' or 'torch'.")
182
+
139
183
  metadata = self.get_metadata(name)
140
184
  view_ptr = lib.ztensor_reader_read_tensor_view(self._ptr, metadata._ptr)
141
185
  _check_ptr(view_ptr, f"read_tensor: {name}")
@@ -143,15 +187,37 @@ class Reader:
143
187
  # Let CFFI manage the lifetime of the view pointer.
144
188
  view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)
145
189
 
146
- # CORRECTED: Create array using the subclass, which handles reshaping and memory.
147
- array = _ZTensorView(
148
- buffer=ffi.buffer(view_ptr.data, view_ptr.len),
149
- dtype=metadata.dtype,
150
- shape=metadata.shape,
151
- view_ptr=view_ptr
152
- )
190
+ if to == 'numpy':
191
+ # Use the custom _ZTensorView to safely manage the FFI pointer lifetime.
192
+ return _ZTensorView(
193
+ buffer=ffi.buffer(view_ptr.data, view_ptr.len),
194
+ dtype=metadata.dtype, # This property raises on unsupported dtypes
195
+ shape=metadata.shape,
196
+ view_ptr=view_ptr
197
+ )
198
+
199
+ elif to == 'torch':
200
+ if not TORCH_AVAILABLE:
201
+ raise ZTensorError("PyTorch is not installed. Cannot return a torch tensor.")
202
+
203
+ # Get the corresponding torch dtype, raising if not supported.
204
+ torch_dtype = DTYPE_ZT_TO_TORCH.get(metadata.dtype_str)
205
+ if torch_dtype is None:
206
+ raise ZTensorError(
207
+ f"Cannot read tensor '{name}' as a PyTorch tensor. "
208
+ f"The dtype '{metadata.dtype_str}' is not supported by PyTorch."
209
+ )
210
+
211
+ # Create a tensor directly from the buffer to avoid numpy conversion issues.
212
+ buffer = ffi.buffer(view_ptr.data, view_ptr.len)
213
+ torch_tensor = torch.frombuffer(buffer, dtype=torch_dtype).reshape(metadata.shape)
153
214
 
154
- return array
215
+ # CRITICAL: Attach the memory owner to the tensor to manage its lifetime.
216
+ # This ensures the Rust memory (held by view_ptr) is not freed while the
217
+ # torch tensor is still in use.
218
+ torch_tensor._owner = view_ptr
219
+
220
+ return torch_tensor
155
221
 
156
222
 
157
223
  class Writer:
@@ -161,8 +227,6 @@ class Writer:
161
227
  path_bytes = file_path.encode('utf-8')
162
228
  ptr = lib.ztensor_writer_create(path_bytes)
163
229
  _check_ptr(ptr, f"Writer create: {file_path}")
164
- # The pointer is consumed by finalize, so we don't use ffi.gc here.
165
- # The writer should be freed via finalize or ztensor_writer_free if finalize fails.
166
230
  self._ptr = ptr
167
231
  self._finalized = False
168
232
 
@@ -170,45 +234,67 @@ class Writer:
170
234
  return self
171
235
 
172
236
  def __exit__(self, exc_type, exc_val, exc_tb):
173
- # Automatically finalize on exit if not already done and no error occurred.
174
237
  if self._ptr and not self._finalized:
175
238
  if exc_type is None:
176
239
  self.finalize()
177
240
  else:
178
- # If an error occurred, don't finalize, just free the writer to prevent leaks.
179
241
  lib.ztensor_writer_free(self._ptr)
180
242
  self._ptr = None
181
243
 
182
- def add_tensor(self, name: str, tensor: np.ndarray):
183
- """Adds a NumPy array as a tensor to the file."""
184
- if not self._ptr: raise ZTensorError("Writer is closed or finalized.")
244
+ def add_tensor(self, name: str, tensor):
245
+ """
246
+ Adds a NumPy or PyTorch tensor to the file (zero-copy).
247
+ Supports float16 and bfloat16 types.
185
248
 
186
- name_bytes = name.encode('utf-8')
187
- tensor = np.ascontiguousarray(tensor) # Ensure data is contiguous.
249
+ Args:
250
+ name (str): The name of the tensor to add.
251
+ tensor (np.ndarray or torch.Tensor): The tensor data to write.
252
+ """
253
+ if not self._ptr: raise ZTensorError("Writer is closed or finalized.")
188
254
 
189
- shape_array = np.array(tensor.shape, dtype=np.uint64)
190
- shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
255
+ # --- Polymorphic tensor handling ---
256
+ if isinstance(tensor, np.ndarray):
257
+ tensor = np.ascontiguousarray(tensor)
258
+ shape = tensor.shape
259
+ dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
260
+ data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
261
+ nbytes = tensor.nbytes
262
+
263
+ elif TORCH_AVAILABLE and isinstance(tensor, torch.Tensor):
264
+ if tensor.is_cuda:
265
+ raise ZTensorError("Cannot write directly from a CUDA tensor. Copy to CPU first using .cpu().")
266
+ tensor = tensor.contiguous()
267
+ shape = tuple(tensor.shape)
268
+ dtype_str = DTYPE_TORCH_TO_ZT.get(tensor.dtype)
269
+ data_ptr = ffi.cast("unsigned char*", tensor.data_ptr())
270
+ nbytes = tensor.numel() * tensor.element_size()
271
+
272
+ else:
273
+ supported = "np.ndarray" + (" or torch.Tensor" if TORCH_AVAILABLE else "")
274
+ raise TypeError(f"Unsupported tensor type: {type(tensor)}. Must be {supported}.")
191
275
 
192
- dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
193
276
  if not dtype_str:
194
- raise ZTensorError(f"Unsupported NumPy dtype: {tensor.dtype}")
195
- dtype_bytes = dtype_str.encode('utf-8')
277
+ msg = f"Unsupported dtype: {tensor.dtype}."
278
+ if 'bfloat16' in str(tensor.dtype) and not ML_DTYPES_AVAILABLE:
279
+ msg += " For NumPy bfloat16 support, please install the 'ml_dtypes' package."
280
+ raise ZTensorError(msg)
196
281
 
197
- # CORRECTED: Cast to `unsigned char*` to match the CFFI definition and Rust FFI.
198
- data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
282
+ name_bytes = name.encode('utf-8')
283
+ shape_array = np.array(shape, dtype=np.uint64)
284
+ shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
285
+ dtype_bytes = dtype_str.encode('utf-8')
199
286
 
200
287
  status = lib.ztensor_writer_add_tensor(
201
- self._ptr, name_bytes, shape_ptr, len(tensor.shape),
202
- dtype_bytes, data_ptr, tensor.nbytes
288
+ self._ptr, name_bytes, shape_ptr, len(shape),
289
+ dtype_bytes, data_ptr, nbytes
203
290
  )
204
291
  _check_status(status, f"add_tensor: {name}")
205
292
 
206
293
  def finalize(self):
207
294
  """Finalizes the zTensor file, writing the metadata index."""
208
295
  if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
209
-
210
296
  status = lib.ztensor_writer_finalize(self._ptr)
211
- self._ptr = None # The writer pointer is consumed and invalidated by the Rust call.
297
+ self._ptr = None
212
298
  self._finalized = True
213
299
  _check_status(status, "finalize")
214
300
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes