ztensor 1.1.1__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ztensor/__init__.py ADDED
@@ -0,0 +1,710 @@
1
+ import numpy as np
2
+ import importlib
3
+ from .ztensor import ffi, lib
4
+ from typing import Union
5
+
6
# --- Optional PyTorch support ---
# torch is resolved through importlib; when it cannot be imported the module
# still loads and only the torch-specific features are disabled.
try:
    _torch = importlib.import_module("torch")
except ImportError:
    _torch = None
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True

# --- Optional ml_dtypes support (bfloat16 for NumPy) ---
try:
    from ml_dtypes import bfloat16 as np_bfloat16
except ImportError:
    np_bfloat16 = None
    ML_DTYPES_AVAILABLE = False
else:
    ML_DTYPES_AVAILABLE = True
21
+
22
+
23
+ # --- Pythonic Wrapper ---
24
class ZTensorError(Exception):
    """Error raised by this package for ztensor-related failures.

    Raised when a native call reports failure, when a dtype or layout is not
    supported, and when a closed reader or writer is used.
    """
27
+
28
+
29
class _ZTensorView(np.ndarray):
    """ndarray subclass that keeps the objects backing its memory alive.

    The array data lives in memory handed out through CFFI. ``_owner`` holds
    the CFFI object whose finalizer releases that memory, and ``_reader_ref``
    optionally holds the reader object for views into the reader's own
    memory. Both attributes are copied to every array derived from this one
    (slices, reshapes, ...), so derived arrays keep the same objects alive.
    """

    def __new__(cls, buffer, dtype, shape, view_ptr):
        """Build a view over ``buffer`` with the given dtype and shape.

        Args:
            buffer: Object exposing the buffer protocol over the raw bytes.
            dtype: NumPy dtype used to interpret the bytes.
            shape: Target shape of the resulting array.
            view_ptr: Object to keep alive for as long as the array exists.
        """
        obj = np.frombuffer(buffer, dtype=dtype).reshape(shape).view(cls)
        obj._owner = view_ptr
        return obj

    def __array_finalize__(self, obj):
        # Called by NumPy for every new instance; ``obj`` is the array this
        # one was derived from, or None for explicit construction.
        if obj is None:
            return
        self._owner = getattr(obj, '_owner', None)
        # Propagate the reader reference as well, so derived views keep the
        # reader alive and the attribute always exists (None when unset).
        self._reader_ref = getattr(obj, '_reader_ref', None)
39
+
40
+
41
def _get_last_error():
    """Return the last error message recorded by the native library.

    Falls back to a generic message when the library returns a NULL pointer
    instead of a message.
    """
    message_ptr = lib.ztensor_last_error_message()
    if message_ptr == ffi.NULL:
        return "Unknown FFI error"
    return ffi.string(message_ptr).decode('utf-8')
47
+
48
+
49
def _check_ptr(ptr, func_name=""):
    """Return ``ptr`` unchanged, or raise ``ZTensorError`` if it is NULL.

    ``func_name`` is only used to label the error message.
    """
    if ptr != ffi.NULL:
        return ptr
    raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")
54
+
55
+
56
def _check_status(status, func_name=""):
    """Raise ``ZTensorError`` unless the FFI status code is zero.

    ``func_name`` is only used to label the error message.
    """
    if status == 0:
        return
    raise ZTensorError(f"Error in {func_name}: {_get_last_error()}")
60
+
61
+
62
# --- Type Mappings ---
# NumPy dtype -> zTensor dtype name. Every zTensor name listed here is also a
# valid NumPy dtype name, so the table is derived from the names alone.
DTYPE_NP_TO_ZT = {
    np.dtype(zt_name): zt_name
    for zt_name in (
        'float64', 'float32', 'float16',
        'int64', 'int32',
        'int16', 'int8',
        'uint64', 'uint32',
        'uint16', 'uint8',
        'bool',
    )
}
# bfloat16 only exists for NumPy when ml_dtypes could be imported.
if ML_DTYPES_AVAILABLE:
    DTYPE_NP_TO_ZT[np.dtype(np_bfloat16)] = 'bfloat16'
# Reverse table: zTensor dtype name -> NumPy dtype.
DTYPE_ZT_TO_NP = dict(
    (zt_name, np_dtype) for np_dtype, zt_name in DTYPE_NP_TO_ZT.items()
)

# The torch tables are only defined when torch could be imported.
if TORCH_AVAILABLE:
    # torch dtype -> zTensor dtype name, looked up by attribute name on torch.
    DTYPE_TORCH_TO_ZT = {
        getattr(_torch, zt_name): zt_name
        for zt_name in (
            'float64', 'float32', 'float16',
            'bfloat16',
            'int64', 'int32',
            'int16', 'int8',
            'uint8', 'bool',
        )
    }
    # Reverse table: zTensor dtype name -> torch dtype.
    DTYPE_ZT_TO_TORCH = dict(
        (zt_name, torch_dtype) for torch_dtype, zt_name in DTYPE_TORCH_TO_ZT.items()
    )
84
+
85
+
86
class TensorMetadata:
    """A Pythonic wrapper around the CTensorMetadata pointer.

    Each field is fetched from the native library on first access and cached
    on the instance afterwards.
    """

    # Marker for optional fields that have not been fetched yet; ``None``
    # cannot be used because it is a legitimate value for those fields.
    _NOT_CHECKED = "not_checked"

    def __init__(self, meta_ptr):
        """Take ownership of ``meta_ptr``.

        Raises:
            ZTensorError: If ``meta_ptr`` is NULL.
        """
        # Validate before registering the finalizer so that
        # ``ztensor_metadata_free`` is never scheduled for a NULL pointer.
        _check_ptr(meta_ptr, "TensorMetadata constructor")
        self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
        self._name = None
        self._dtype_str = None
        self._shape = None
        self._offset = None
        self._size = None
        self._layout = None
        self._encoding = self._NOT_CHECKED
        self._endianness = self._NOT_CHECKED
        self._checksum = self._NOT_CHECKED

    def __repr__(self):
        return f"<TensorMetadata name='{self.name}' shape={self.shape} dtype='{self.dtype_str}'>"

    @staticmethod
    def _consume_string(c_str):
        """Decode a library-owned C string and free it, even if decoding fails."""
        try:
            return ffi.string(c_str).decode('utf-8')
        finally:
            lib.ztensor_free_string(c_str)

    def _required_string(self, getter, func_name):
        """Fetch a string field that must be present; raise ZTensorError on NULL."""
        c_str = getter(self._ptr)
        _check_ptr(c_str, func_name)
        return self._consume_string(c_str)

    def _optional_string(self, getter):
        """Fetch a string field that may be absent; NULL maps to ``None``."""
        c_str = getter(self._ptr)
        if c_str == ffi.NULL:
            return None
        return self._consume_string(c_str)

    @property
    def name(self):
        """The name of the tensor."""
        if self._name is None:
            self._name = self._required_string(lib.ztensor_metadata_get_name, "get_name")
        return self._name

    @property
    def dtype_str(self):
        """The zTensor dtype string (e.g., 'float32')."""
        if self._dtype_str is None:
            self._dtype_str = self._required_string(
                lib.ztensor_metadata_get_dtype_str, "get_dtype_str"
            )
        return self._dtype_str

    @property
    def dtype(self):
        """The numpy dtype for this tensor.

        Raises:
            ZTensorError: If the dtype string has no NumPy equivalent; for
                'bfloat16' this means the 'ml_dtypes' package is missing.
        """
        dtype_str = self.dtype_str
        dt = DTYPE_ZT_TO_NP.get(dtype_str)
        if dt is None:
            if dtype_str == 'bfloat16':
                raise ZTensorError(
                    "Cannot read 'bfloat16' tensor as NumPy array because the 'ml_dtypes' "
                    "package is not installed. Please install it to proceed."
                )
            raise ZTensorError(f"Unsupported or unknown dtype string '{dtype_str}' found in tensor metadata.")
        return dt

    @property
    def shape(self):
        """The shape of the tensor as a tuple."""
        if self._shape is None:
            shape_len = lib.ztensor_metadata_get_shape_len(self._ptr)
            if shape_len > 0:
                shape_data_ptr = lib.ztensor_metadata_get_shape_data(self._ptr)
                _check_ptr(shape_data_ptr, "get_shape_data")
                try:
                    self._shape = tuple(shape_data_ptr[i] for i in range(shape_len))
                finally:
                    # Always hand the array back, even if conversion failed.
                    lib.ztensor_free_u64_array(shape_data_ptr, shape_len)
            else:
                self._shape = tuple()
        return self._shape

    @property
    def offset(self):
        """The on-disk offset of the tensor data in bytes."""
        if self._offset is None:
            self._offset = lib.ztensor_metadata_get_offset(self._ptr)
        return self._offset

    @property
    def size(self):
        """The on-disk size of the tensor data in bytes (can be compressed size)."""
        if self._size is None:
            self._size = lib.ztensor_metadata_get_size(self._ptr)
        return self._size

    @property
    def layout(self):
        """The tensor layout as a string (e.g., 'dense')."""
        if self._layout is None:
            self._layout = self._required_string(
                lib.ztensor_metadata_get_layout_str, "get_layout_str"
            )
        return self._layout

    @property
    def encoding(self):
        """The tensor encoding as a string (e.g., 'raw', 'zstd'), or None if absent."""
        if self._encoding == self._NOT_CHECKED:
            self._encoding = self._optional_string(lib.ztensor_metadata_get_encoding_str)
        return self._encoding

    @property
    def endianness(self):
        """The data endianness ('little', 'big') if applicable, else None."""
        if self._endianness == self._NOT_CHECKED:
            self._endianness = self._optional_string(
                lib.ztensor_metadata_get_data_endianness_str
            )
        return self._endianness

    @property
    def checksum(self):
        """The checksum string if present, else None."""
        if self._checksum == self._NOT_CHECKED:
            self._checksum = self._optional_string(lib.ztensor_metadata_get_checksum_str)
        return self._checksum
212
+
213
+
214
class Reader:
    """A Pythonic context manager for reading zTensor files.

    The native reader handle is wrapped with a CFFI finalizer; leaving the
    ``with`` block drops this object's reference to it and marks the reader
    as closed.
    """

    def __init__(self, file_path):
        """Open ``file_path`` (a ``str`` or ``os.PathLike``) for reading.

        Raises:
            ZTensorError: If the native library fails to open the file.
        """
        import os  # local import: only needed to accept path-like objects

        path_bytes = os.fspath(file_path).encode('utf-8')
        ptr = lib.ztensor_reader_open(path_bytes)
        _check_ptr(ptr, f"Reader open: {file_path}")
        self._ptr = ffi.gc(ptr, lib.ztensor_reader_free)
        self._tensor_names_cache = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._ptr = None
        self._tensor_names_cache = None

    def _ensure_open(self):
        """Raise ``ZTensorError`` if the reader has been closed."""
        if self._ptr is None:
            raise ZTensorError("Reader is closed.")

    def __len__(self):
        """Returns the number of tensors in the file."""
        self._ensure_open()
        return lib.ztensor_reader_get_metadata_count(self._ptr)

    def __iter__(self):
        """Iterates over the metadata of all tensors in the file."""
        self._ensure_open()
        for name in self.tensor_names:
            yield self.metadata(name)

    def __getitem__(self, key):
        """
        Retrieves metadata (int key) or reads tensor data (str key).

        Negative integers index from the end, like a normal sequence.

        Raises:
            IndexError: If an integer key is out of range.
            TypeError: If the key is neither ``int`` nor ``str``.
            ZTensorError: If the reader is closed or the native call fails.
        """
        self._ensure_open()

        if isinstance(key, int):
            count = len(self)
            index = key + count if key < 0 else key
            # Never hand a negative or too-large index to the native call.
            if not 0 <= index < count:
                raise IndexError("Tensor index out of range")
            meta_ptr = lib.ztensor_reader_get_metadata_by_index(self._ptr, index)
            _check_ptr(meta_ptr, f"get_metadata_by_index: {key}")
            return TensorMetadata(meta_ptr)
        elif isinstance(key, str):
            return self.read_tensor(key)
        else:
            raise TypeError(f"Invalid argument type for __getitem__: {type(key)}")

    def __contains__(self, name: str) -> bool:
        """Checks if a tensor with the given name exists in the file."""
        return name in self.tensor_names

    @property
    def tensors(self) -> "list[TensorMetadata]":
        """Returns a list of all TensorMetadata objects in the file."""
        return list(self)

    @property
    def tensor_names(self) -> list[str]:
        """Returns a list of all tensor names in the file (cached)."""
        self._ensure_open()
        if self._tensor_names_cache is not None:
            return self._tensor_names_cache

        c_array_ptr = lib.ztensor_reader_get_all_tensor_names(self._ptr)
        _check_ptr(c_array_ptr, "get_all_tensor_names")
        c_array_ptr = ffi.gc(c_array_ptr, lib.ztensor_free_string_array)

        self._tensor_names_cache = [
            ffi.string(c_array_ptr.strings[i]).decode('utf-8')
            for i in range(c_array_ptr.len)
        ]
        return self._tensor_names_cache

    def metadata(self, name: str) -> "TensorMetadata":
        """Retrieves metadata for a tensor by its name."""
        self._ensure_open()
        name_bytes = name.encode('utf-8')
        meta_ptr = lib.ztensor_reader_get_metadata_by_name(self._ptr, name_bytes)
        _check_ptr(meta_ptr, f"metadata: {name}")
        return TensorMetadata(meta_ptr)

    # Legacy aliases
    def list_tensors(self): return self.tensors
    def get_tensor_names(self): return self.tensor_names
    def get_metadata(self, name): return self.metadata(name)

    def _read_component(self, tensor_name: str, component_name: str, dtype_func):
        """Reads a specific component as a flat numpy array.

        Returns:
            ``(array, view_ptr)``. The array is a view over native memory that
            is released by ``view_ptr``'s finalizer, so the caller must keep
            ``view_ptr`` alive for as long as the array is in use.
        """
        t_name_bytes = tensor_name.encode('utf-8')
        c_name_bytes = component_name.encode('utf-8')

        view_ptr = lib.ztensor_reader_read_tensor_component(self._ptr, t_name_bytes, c_name_bytes)
        _check_ptr(view_ptr, f"read_component: {tensor_name}.{component_name}")
        view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)

        buffer = ffi.buffer(view_ptr.data, view_ptr.len)
        arr = np.frombuffer(buffer, dtype=dtype_func())
        return arr, view_ptr

    def _dense_from_view(self, data, length, meta, to, owner, zero_copy):
        """Wrap ``length`` bytes at ``data`` as a dense NumPy array or torch tensor.

        Args:
            data: CFFI pointer to the first byte of the tensor data.
            length: Number of bytes available at ``data``.
            meta: TensorMetadata describing dtype and shape.
            to: 'numpy' or 'torch'.
            owner: CFFI object whose finalizer releases the native memory.
            zero_copy: True when the bytes are a slice handed out by the
                reader itself, in which case the result must keep the reader
                alive too.
        """
        if to == 'numpy':
            arr = _ZTensorView(
                buffer=ffi.buffer(data, length),
                dtype=meta.dtype,
                shape=meta.shape,
                view_ptr=owner,
            )
            if zero_copy:
                arr._reader_ref = self
            return arr

        if not TORCH_AVAILABLE:
            raise ZTensorError("PyTorch not installed.")
        torch_dtype = DTYPE_ZT_TO_TORCH.get(meta.dtype_str)
        if torch_dtype is None:
            # Without this check torch.frombuffer would be called with dtype=None.
            raise ZTensorError(f"Unsupported dtype for PyTorch: '{meta.dtype_str}'.")

        # torch.frombuffer keeps a reference to the object it is given, so the
        # bytes are exposed through a uint8 _ZTensorView that itself holds the
        # owner (and the reader for zero-copy data). Tensors derived from the
        # result share that storage and therefore keep the memory alive,
        # instead of relying on an attribute of this one tensor object.
        anchor = _ZTensorView(
            buffer=ffi.buffer(data, length),
            dtype=np.uint8,
            shape=(length,),
            view_ptr=owner,
        )
        if zero_copy:
            anchor._reader_ref = self
        tensor = _torch.frombuffer(anchor, dtype=torch_dtype).reshape(meta.shape)
        # Kept for backward compatibility with code inspecting these attributes.
        tensor._owner = owner
        if zero_copy:
            tensor._reader_ref = self
        return tensor

    def _read_dense(self, name, metadata, to, verify_checksum):
        """Read a dense tensor, using the zero-copy slice path when possible."""
        name_bytes = name.encode('utf-8')
        view_ptr = ffi.NULL

        # The slice call takes no checksum flag, so it is only used when the
        # caller did not ask for verification; otherwise the flag would be
        # silently ignored for raw tensors.
        if metadata.encoding == 'raw' and not verify_checksum:
            # Older builds of the library may not export the slice function.
            get_slice = getattr(lib, 'ztensor_reader_get_tensor_slice', None)
            if get_slice is not None:
                view_ptr = get_slice(self._ptr, name_bytes)

        is_zero_copy = view_ptr != ffi.NULL
        if not is_zero_copy:
            # Fallback to standard read (copy)
            view_ptr = lib.ztensor_reader_read_tensor(self._ptr, name_bytes, 1 if verify_checksum else 0)
            _check_ptr(view_ptr, f"read_tensor: {name}")

        view_ptr = ffi.gc(view_ptr, lib.ztensor_free_tensor_view)
        return self._dense_from_view(
            view_ptr.data, view_ptr.len, metadata, to, view_ptr, is_zero_copy
        )

    def _read_sparse_csr(self, name, metadata, to):
        """Read a 'sparse_csr' tensor as a scipy CSR matrix or torch CSR tensor."""
        vals, v_ref = self._read_component(name, "values", lambda: metadata.dtype)
        idxs, i_ref = self._read_component(name, "indices", lambda: np.uint64)
        ptrs, p_ref = self._read_component(name, "indptr", lambda: np.uint64)

        if to == 'numpy':
            try:
                from scipy.sparse import csr_matrix
            except ImportError as err:
                raise ZTensorError("scipy is required for reading sparse tensors as numpy.") from err
            # The index arrays are converted (copied) to int32; the values
            # array is passed as the view over native memory.
            result = csr_matrix((vals, idxs.astype(np.int32), ptrs.astype(np.int32)), shape=metadata.shape)
            # Keep FFI view pointers alive for the lifetime of the sparse matrix
            result._ztensor_owners = (v_ref, i_ref, p_ref)
            return result

        if not TORCH_AVAILABLE:
            raise ZTensorError("No Torch.")
        # Every array is copied, so the torch tensor does not depend on the
        # native views.
        t_vals = _torch.from_numpy(vals.copy())
        t_indptr = _torch.from_numpy(ptrs.astype(np.int64))
        t_indices = _torch.from_numpy(idxs.astype(np.int64))
        return _torch.sparse_csr_tensor(t_indptr, t_indices, t_vals, size=metadata.shape)

    def _read_sparse_coo(self, name, metadata, to):
        """Read a 'sparse_coo' tensor as a scipy COO matrix or torch COO tensor."""
        vals, v_ref = self._read_component(name, "values", lambda: metadata.dtype)
        coords, c_ref = self._read_component(name, "coords", lambda: np.uint64)

        nnz = vals.shape[0]
        ndim = len(metadata.shape)
        # Coordinates are stored flat; view them as one row per dimension.
        coords = coords.reshape((ndim, nnz))

        if to == 'numpy':
            if ndim != 2:
                raise ZTensorError("Scipy COO only supports 2D.")
            from scipy.sparse import coo_matrix
            result = coo_matrix((vals, (coords[0], coords[1])), shape=metadata.shape)
            # Keep FFI view pointers alive for the lifetime of the sparse matrix
            result._ztensor_owners = (v_ref, c_ref)
            return result

        if not TORCH_AVAILABLE:
            raise ZTensorError("No Torch.")
        # Every array is copied, so the torch tensor does not depend on the
        # native views.
        t_vals = _torch.from_numpy(vals.copy())
        t_indices = _torch.from_numpy(coords.astype(np.int64))
        return _torch.sparse_coo_tensor(t_indices, t_vals, size=metadata.shape)

    def read_tensor(self, name: str, to: str = 'numpy', verify_checksum: bool = False):
        """
        Reads a tensor by name and returns it as a NumPy array or PyTorch tensor.

        Args:
            name (str): The name of the tensor to read.
            to (str): The desired output format. 'numpy' or 'torch'.
            verify_checksum (bool): If True, verify checksums during read (slower);
                this also disables the zero-copy path for raw dense tensors.
                Default: False.

        Returns:
            Union[np.ndarray, torch.Tensor, scipy.sparse.spmatrix, torch.sparse_coo_tensor]

        Raises:
            ValueError: If ``to`` is not 'numpy' or 'torch'.
            ZTensorError: If the reader is closed, the layout or dtype is not
                supported, a required optional package is missing, or the
                native read fails.
        """
        self._ensure_open()
        if to not in ['numpy', 'torch']:
            raise ValueError(f"Unsupported format: '{to}'.")

        metadata = self.metadata(name)
        layout = metadata.layout

        if layout == "dense":
            return self._read_dense(name, metadata, to, verify_checksum)
        if layout == "sparse_csr":
            return self._read_sparse_csr(name, metadata, to)
        if layout == "sparse_coo":
            return self._read_sparse_coo(name, metadata, to)
        raise ZTensorError(f"Unsupported layout: {layout}")

    def read_tensors(self, names: list[str], to: str = 'numpy', verify_checksum: bool = False) -> list:
        """
        Reads multiple tensors in batch.

        Dense tensors are fetched with a single native batch call; tensors
        with any other layout are read one by one through ``read_tensor``.

        Args:
            names (list[str]): List of tensor names to read.
            to (str): The desired output format. 'numpy' or 'torch'.
            verify_checksum (bool): If True, verify checksums during read. Default: False.

        Returns:
            List of tensors in the same order as input names.
        """
        self._ensure_open()
        if to not in ['numpy', 'torch']:
            raise ValueError(f"Unsupported format: '{to}'.")

        # First pass: collect metadata and identify dense tensors
        metadatas = [self.metadata(name) for name in names]
        dense = [
            (name, meta) for name, meta in zip(names, metadatas)
            if meta.layout == "dense"
        ]

        # Batch read all dense tensors
        dense_results = {}
        if dense:
            # The char[] objects must stay referenced until the call returns.
            name_ptrs = [ffi.new("char[]", name.encode('utf-8')) for name, _ in dense]
            name_ptr_array = ffi.new("char*[]", name_ptrs)

            arr_ptr = lib.ztensor_reader_read_tensors(
                self._ptr,
                name_ptr_array,
                len(dense),
                1 if verify_checksum else 0
            )
            _check_ptr(arr_ptr, "read_tensors batch")
            # Keep arr_ptr alive - it owns all the views
            arr_ptr = ffi.gc(arr_ptr, lib.ztensor_free_tensor_view_array)

            for idx, (name, meta) in enumerate(dense):
                view = arr_ptr.views[idx]
                # If this is a zero-copy view (owner is null), we must keep Reader alive
                zero_copy = view._owner == ffi.NULL
                dense_results[name] = self._dense_from_view(
                    view.data, view.len, meta, to, arr_ptr, zero_copy
                )

        # Assemble final results; anything not read in the batch falls back
        # to an individual read.
        return [
            dense_results[name] if name in dense_results
            else self.read_tensor(name, to=to, verify_checksum=verify_checksum)
            for name in names
        ]
510
+
511
+
512
class Writer:
    """A Pythonic context manager for writing zTensor files.

    Leaving the ``with`` block without an exception finalizes the file; if an
    exception is propagating, the native writer is freed without finalizing.
    """

    def __init__(self, file_path):
        """Create ``file_path`` (a ``str`` or ``os.PathLike``) for writing.

        Raises:
            ZTensorError: If the native library fails to create the writer.
        """
        import os  # local import: only needed to accept path-like objects

        path_bytes = os.fspath(file_path).encode('utf-8')
        ptr = lib.ztensor_writer_create(path_bytes)
        _check_ptr(ptr, f"Writer create: {file_path}")
        self._ptr = ptr
        self._finalized = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._ptr and not self._finalized:
            if exc_type is None:
                self.finalize()
            else:
                lib.ztensor_writer_free(self._ptr)
                self._ptr = None

    def add_tensor(self, name: str, tensor, compress: Union[bool, int] = False):
        """
        Adds a NumPy or PyTorch tensor to the file.

        Args:
            name (str): The name of the tensor to add.
            tensor (np.ndarray or torch.Tensor): The tensor data to write.
            compress (bool or int): Compression settings.
                - False / 0: No compression (Raw).
                - True: Default Zstd compression (Level 3).
                - int > 0: Specific Zstd compression level (e.g., 1-22).

        Raises:
            TypeError: If ``compress`` or ``tensor`` has an unsupported type.
            ZTensorError: If the writer is closed, the tensor is on CUDA, the
                dtype is unsupported, or the native call fails.
        """
        if not self._ptr: raise ZTensorError("Writer is closed or finalized.")

        # Resolve compression level. The identity checks for True/False come
        # first because bool is a subclass of int.
        if compress is True:
            compression_level = 3  # Default zstd level
        elif compress is False or compress is None:
            compression_level = 0
        elif isinstance(compress, int):
            compression_level = compress
        else:
            raise TypeError(f"Invalid type for 'compress': {type(compress)}. Expected bool or int.")

        if isinstance(tensor, np.ndarray):
            tensor = np.ascontiguousarray(tensor)
            shape = tensor.shape
            dtype_str = DTYPE_NP_TO_ZT.get(tensor.dtype)
            data_ptr = ffi.cast("unsigned char*", tensor.ctypes.data)
            nbytes = tensor.nbytes

        elif TORCH_AVAILABLE and isinstance(tensor, _torch.Tensor):
            if tensor.is_cuda:
                raise ZTensorError("Cannot write directly from a CUDA tensor. Copy to CPU first using .cpu().")
            tensor = tensor.contiguous()
            shape = tuple(tensor.shape)
            dtype_str = DTYPE_TORCH_TO_ZT.get(tensor.dtype)
            data_ptr = ffi.cast("unsigned char*", tensor.data_ptr())
            nbytes = tensor.numel() * tensor.element_size()

        else:
            supported = "np.ndarray" + (" or torch.Tensor" if TORCH_AVAILABLE else "")
            raise TypeError(f"Unsupported tensor type: {type(tensor)}. Must be {supported}.")

        if not dtype_str:
            msg = f"Unsupported dtype: {tensor.dtype}."
            if 'bfloat16' in str(tensor.dtype) and not ML_DTYPES_AVAILABLE:
                msg += " For NumPy bfloat16 support, please install the 'ml_dtypes' package."
            raise ZTensorError(msg)

        name_bytes = name.encode('utf-8')
        shape_array = np.array(shape, dtype=np.uint64)
        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
        dtype_bytes = dtype_str.encode('utf-8')

        # ffi.cast does not keep its source alive; ``tensor`` and
        # ``shape_array`` are locals of this frame, so the memory behind the
        # raw pointers stays valid until the call returns.
        status = lib.ztensor_writer_add_tensor(
            self._ptr, name_bytes, shape_ptr, len(shape),
            dtype_bytes, data_ptr, nbytes, compression_level
        )
        _check_status(status, f"add_tensor: {name}")

    def add_sparse_csr(self, name: str, values, indices, indptr, shape):
        """Adds a sparse CSR tensor.

        Args:
            name (str): The name of the tensor to add.
            values: np.ndarray or torch.Tensor holding the stored values.
            indices: np.ndarray or torch.Tensor of indices; passed to the
                native library as 64-bit integers.
            indptr: np.ndarray or torch.Tensor of row pointers; passed to the
                native library as 64-bit integers.
            shape: Shape of the tensor.

        Raises:
            TypeError: If an input is neither np.ndarray nor torch.Tensor.
            ZTensorError: If the writer is closed, an input is on CUDA, the
                values dtype is unsupported, or the native call fails.
        """
        if not self._ptr: raise ZTensorError("Writer is closed.")

        # Keep all arrays alive during FFI call
        keepalive = []

        def get_buffer_info(arr, force_dtype=None):
            """Return (byte pointer, byte length, element count, dtype name)."""
            if isinstance(arr, np.ndarray):
                if force_dtype is not None and arr.dtype != force_dtype:
                    arr = arr.astype(force_dtype)
                arr = np.ascontiguousarray(arr)
                keepalive.append(arr)  # Prevent GC
                return (
                    ffi.cast("unsigned char*", arr.ctypes.data),
                    arr.nbytes,
                    arr.size,
                    DTYPE_NP_TO_ZT.get(arr.dtype),
                )
            if TORCH_AVAILABLE and isinstance(arr, _torch.Tensor):
                if arr.is_cuda: raise ZTensorError("CUDA tensors not supported. Move to CPU.")
                # A forced dtype is only ever requested for the index arrays,
                # which are handed over as uint64_t*. The tensor must really
                # hold 64-bit elements, otherwise the native side would read
                # ``count`` 8-byte values from a smaller buffer. int64 is used
                # as the 64-bit carrier, as in add_sparse_coo.
                if force_dtype is not None and arr.dtype != _torch.int64:
                    arr = arr.to(_torch.int64)
                arr = arr.contiguous()
                keepalive.append(arr)  # Prevent GC
                return (
                    ffi.cast("unsigned char*", arr.data_ptr()),
                    arr.numel() * arr.element_size(),
                    arr.numel(),
                    DTYPE_TORCH_TO_ZT.get(arr.dtype),
                )
            raise TypeError(f"Unsupported array type: {type(arr)}")

        # Get values as bytes
        v_ptr, v_len, _, v_dtype = get_buffer_info(values)
        if not v_dtype:
            # Previously this surfaced as AttributeError on None.encode().
            raise ZTensorError(f"Unsupported dtype: {values.dtype}.")

        # Convert indices and indptr to 64-bit integers
        i_ptr, _, i_cnt, _ = get_buffer_info(indices, force_dtype=np.uint64)
        p_ptr, _, p_cnt, _ = get_buffer_info(indptr, force_dtype=np.uint64)

        # Cast to u64* for FFI
        i_ptr_u64 = ffi.cast("uint64_t*", i_ptr)
        p_ptr_u64 = ffi.cast("uint64_t*", p_ptr)

        name_bytes = name.encode('utf-8')
        shape_array = np.array(shape, dtype=np.uint64)
        keepalive.append(shape_array)
        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
        dtype_bytes = v_dtype.encode('utf-8')

        status = lib.ztensor_writer_add_sparse_csr(
            self._ptr,
            name_bytes,
            shape_ptr, len(shape),
            dtype_bytes,
            v_ptr, v_len,
            i_ptr_u64, i_cnt,
            p_ptr_u64, p_cnt,
        )
        _check_status(status, f"add_sparse_csr: {name}")

    def add_sparse_coo(self, name: str, values, indices, shape):
        """Adds a sparse COO tensor.

        Args:
            name (str): The name of the tensor to add.
            values: np.ndarray or torch.Tensor holding the stored values.
            indices: np.ndarray or torch.Tensor of coordinates; passed to the
                native library as 64-bit integers.
            shape: Shape of the tensor.

        Raises:
            TypeError: If an input is neither np.ndarray nor torch.Tensor.
            ZTensorError: If the writer is closed, an input is on CUDA, the
                values dtype is unsupported, or the native call fails.
        """
        if not self._ptr: raise ZTensorError("Writer is closed.")

        def get_buffer_info(arr):
            """Return (contiguous array, byte pointer, byte length, dtype name)."""
            if isinstance(arr, np.ndarray):
                arr = np.ascontiguousarray(arr)
                return arr, ffi.cast("unsigned char*", arr.ctypes.data), arr.nbytes, DTYPE_NP_TO_ZT.get(arr.dtype)
            elif TORCH_AVAILABLE and isinstance(arr, _torch.Tensor):
                if arr.is_cuda: raise ZTensorError("No CUDA")
                arr = arr.contiguous()
                return arr, ffi.cast("unsigned char*", arr.data_ptr()), arr.numel() * arr.element_size(), DTYPE_TORCH_TO_ZT.get(arr.dtype)
            raise TypeError("Unsupported")

        # ``v_arr`` is held in a local so the memory behind ``v_ptr`` stays
        # valid until the native call returns.
        v_arr, v_ptr, v_len, v_dtype = get_buffer_info(values)
        if not v_dtype:
            # Previously this surfaced as AttributeError on None.encode().
            raise ZTensorError(f"Unsupported dtype: {values.dtype}.")

        def ensure_u64_ptr(arr):
            """Return (uint64_t pointer, backing array, element count)."""
            if isinstance(arr, np.ndarray):
                if arr.dtype != np.uint64: arr = arr.astype(np.uint64)
                arr = np.ascontiguousarray(arr)
                return ffi.cast("uint64_t*", arr.ctypes.data), arr, arr.size
            elif TORCH_AVAILABLE and isinstance(arr, _torch.Tensor):
                # data_ptr() of a CUDA tensor is not host memory.
                if arr.is_cuda: raise ZTensorError("No CUDA")
                if arr.dtype != _torch.int64: arr = arr.to(_torch.int64)
                arr = arr.contiguous()
                return ffi.cast("uint64_t*", arr.data_ptr()), arr, arr.numel()
            raise TypeError("Unsupported")

        # ``_i_keep`` keeps the converted index array alive during the call.
        i_ptr_u64, _i_keep, i_count = ensure_u64_ptr(indices)

        name_bytes = name.encode('utf-8')
        shape_array = np.array(shape, dtype=np.uint64)
        shape_ptr = ffi.cast("uint64_t*", shape_array.ctypes.data)
        dtype_bytes = v_dtype.encode('utf-8')

        status = lib.ztensor_writer_add_sparse_coo(
            self._ptr,
            name_bytes,
            shape_ptr, len(shape),
            dtype_bytes,
            v_ptr, v_len,
            i_ptr_u64, i_count
        )
        _check_status(status, f"add_sparse_coo: {name}")

    def finalize(self):
        """Finalizes the zTensor file, writing the metadata index.

        The writer is marked closed before the status is checked, so it
        cannot be used again even if finalization fails.

        Raises:
            ZTensorError: If the writer is already closed or finalization fails.
        """
        if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
        status = lib.ztensor_writer_finalize(self._ptr)
        self._ptr = None
        self._finalized = True
        _check_status(status, "finalize")
708
+
709
+
710
# Names exported by ``from ztensor import *``.
__all__ = [
    "Reader",
    "Writer",
    "TensorMetadata",
    "ZTensorError",
]