ztensor 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ztensor might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ztensor
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: MIT License
@@ -10,7 +10,7 @@ build-backend = "maturin"
10
10
  [project]
11
11
  # The name of your package on PyPI.
12
12
  name = "ztensor"
13
- version = "0.1.2"
13
+ version = "0.1.4"
14
14
  description = "Python bindings for the zTensor library."
15
15
  readme = "README.md" # It's good practice to have a README.
16
16
  authors = [
@@ -0,0 +1,171 @@
1
+ import unittest
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import numpy as np
6
+
7
+ # --- Conditional PyTorch Import ---
8
+ try:
9
+ import torch
10
+
11
+ TORCH_AVAILABLE = True
12
+ except ImportError:
13
+ TORCH_AVAILABLE = False
14
+
15
+ # --- Import the ztensor wrapper ---
16
+ # Assuming the bindings file is named 'bindings.py' in a 'ztensor' package/directory
17
+ from ztensor import Reader, Writer, ZTensorError, TensorMetadata
18
+
19
+
20
+ class TestZTensorBindings(unittest.TestCase):
21
+
22
+ def setUp(self):
23
+ """Set up a temporary directory for test files."""
24
+ self.test_dir = tempfile.mkdtemp()
25
+ self.test_file = os.path.join(self.test_dir, "test.zt")
26
+
27
+ def tearDown(self):
28
+ """Remove the temporary directory and its contents."""
29
+ shutil.rmtree(self.test_dir)
30
+
31
+ def test_01_writer_and_reader_numpy(self):
32
+ """Test writing and reading a single NumPy tensor."""
33
+ tensor_a = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
34
+ tensor_b = np.array([[True, False], [False, True]], dtype=bool)
35
+
36
+ with Writer(self.test_file) as writer:
37
+ writer.add_tensor("tensor_a", tensor_a)
38
+ writer.add_tensor("tensor_b", tensor_b)
39
+
40
+ self.assertTrue(os.path.exists(self.test_file))
41
+
42
+ with Reader(self.test_file) as reader:
43
+ read_a = reader.read_tensor("tensor_a")
44
+ read_b = reader.read_tensor("tensor_b")
45
+
46
+ self.assertIsInstance(read_a, np.ndarray)
47
+ self.assertTrue(np.array_equal(tensor_a, read_a))
48
+ self.assertEqual(tensor_a.dtype, read_a.dtype)
49
+ self.assertEqual(tensor_a.shape, read_a.shape)
50
+
51
+ self.assertIsInstance(read_b, np.ndarray)
52
+ self.assertTrue(np.array_equal(tensor_b, read_b))
53
+
54
+ @unittest.skipIf(not TORCH_AVAILABLE, "PyTorch not installed")
55
+ def test_02_writer_and_reader_torch(self):
56
+ """Test writing and reading a PyTorch tensor."""
57
+ tensor_a = torch.randn(10, 20, dtype=torch.float16)
58
+ tensor_b = torch.randint(0, 255, (100,), dtype=torch.uint8)
59
+
60
+ with Writer(self.test_file) as writer:
61
+ writer.add_tensor("torch_a", tensor_a)
62
+ writer.add_tensor("torch_b", tensor_b)
63
+
64
+ with Reader(self.test_file) as reader:
65
+ # Read back as torch tensor
66
+ read_a_torch = reader.read_tensor("torch_a", to='torch')
67
+ self.assertIsInstance(read_a_torch, torch.Tensor)
68
+ self.assertTrue(torch.equal(tensor_a, read_a_torch))
69
+ self.assertEqual(tensor_a.dtype, read_a_torch.dtype)
70
+
71
+ # Read back as numpy array
72
+ read_b_np = reader.read_tensor("torch_b", to='numpy')
73
+ self.assertIsInstance(read_b_np, np.ndarray)
74
+ self.assertTrue(np.array_equal(tensor_b.numpy(), read_b_np))
75
+
76
+ def test_03_metadata_access_and_iteration(self):
77
+ """Test the reader's container and metadata features."""
78
+ tensors = {
79
+ "tensor_int": np.ones((5, 5), dtype=np.int32),
80
+ "tensor_float": np.zeros((10,), dtype=np.float64),
81
+ "scalar": np.array(3.14, dtype=np.float32)
82
+ }
83
+ tensor_names = sorted(tensors.keys())
84
+
85
+ with Writer(self.test_file) as writer:
86
+ for name, tensor in tensors.items():
87
+ writer.add_tensor(name, tensor)
88
+
89
+ with Reader(self.test_file) as reader:
90
+ # Test __len__
91
+ self.assertEqual(len(reader), 3)
92
+
93
+ # Test get_tensor_names
94
+ read_names = sorted(reader.get_tensor_names())
95
+ self.assertEqual(tensor_names, read_names)
96
+
97
+ # Test __iter__ and __getitem__
98
+ all_meta = []
99
+ for i in range(len(reader)):
100
+ meta = reader[i]
101
+ self.assertIsInstance(meta, TensorMetadata)
102
+ all_meta.append(meta.name)
103
+ self.assertEqual(tensor_names, sorted(all_meta))
104
+
105
+ # Test list_tensors
106
+ self.assertEqual(len(reader.list_tensors()), 3)
107
+
108
+ # Test specific metadata properties
109
+ meta_scalar = reader.get_metadata("scalar")
110
+ self.assertEqual(meta_scalar.name, "scalar")
111
+ self.assertEqual(meta_scalar.shape, (1, ))
112
+ self.assertEqual(meta_scalar.dtype, np.dtype('float32'))
113
+ self.assertEqual(meta_scalar.dtype_str, 'float32')
114
+ self.assertGreater(meta_scalar.offset, 0)
115
+ self.assertGreater(meta_scalar.size, 0)
116
+ self.assertEqual(meta_scalar.layout, "dense")
117
+ self.assertEqual(meta_scalar.encoding, "raw")
118
+ self.assertIn(meta_scalar.endianness, ["little", "big"])
119
+ self.assertIsNone(meta_scalar.checksum)
120
+
121
+ def test_04_error_handling(self):
122
+ """Test expected failure modes."""
123
+ tensor = np.arange(10)
124
+ with Writer(self.test_file) as writer:
125
+ writer.add_tensor("my_tensor", tensor)
126
+
127
+ # Test reading non-existent tensor
128
+ with Reader(self.test_file) as reader:
129
+ with self.assertRaisesRegex(ZTensorError, "Tensor not found"):
130
+ reader.read_tensor("non_existent_tensor")
131
+
132
+ # Test accessing closed reader
133
+ reader = Reader(self.test_file)
134
+ reader.__exit__(None, None, None) # Manually close
135
+ with self.assertRaisesRegex(ZTensorError, "Reader is closed"):
136
+ len(reader)
137
+ with self.assertRaisesRegex(ZTensorError, "Reader is closed"):
138
+ reader.read_tensor("my_tensor")
139
+
140
+ # Test index out of range
141
+ with Reader(self.test_file) as reader:
142
+ with self.assertRaises(IndexError):
143
+ reader[99]
144
+
145
+ # Test writing to a finalized writer
146
+ writer = Writer(self.test_file)
147
+ writer.add_tensor("t1", tensor)
148
+ writer.finalize()
149
+ with self.assertRaisesRegex(ZTensorError, "Writer is closed or finalized."):
150
+ writer.add_tensor("t2", tensor)
151
+
152
+ def test_05_writer_context_manager_exception(self):
153
+ """Ensure writer does not finalize on error inside `with` block."""
154
+ bad_file_path = os.path.join(self.test_dir, "bad_file.zt")
155
+ try:
156
+ with Writer(bad_file_path) as writer:
157
+ writer.add_tensor("good_tensor", np.arange(5))
158
+ raise ValueError("Simulating an error during writing")
159
+ except ValueError:
160
+ pass # Expected
161
+
162
+ # The file should exist but be empty/invalid because finalize was not called.
163
+ self.assertTrue(os.path.exists(bad_file_path))
164
+ # Attempting to read should fail
165
+ with self.assertRaises(ZTensorError):
166
+ with Reader(bad_file_path) as reader:
167
+ len(reader)
168
+
169
+
170
+ if __name__ == '__main__':
171
+ unittest.main(verbosity=2)
@@ -92,12 +92,23 @@ class TensorMetadata:
92
92
  def __init__(self, meta_ptr):
93
93
  self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
94
94
  _check_ptr(self._ptr, "TensorMetadata constructor")
95
+ # Cache for properties to avoid repeated FFI calls
95
96
  self._name = None
96
97
  self._dtype_str = None
97
98
  self._shape = None
99
+ self._offset = None
100
+ self._size = None
101
+ self._layout = None
102
+ self._encoding = None
103
+ self._endianness = "not_checked"
104
+ self._checksum = "not_checked"
105
+
106
+ def __repr__(self):
107
+ return f"<TensorMetadata name='{self.name}' shape={self.shape} dtype='{self.dtype_str}'>"
98
108
 
99
109
  @property
100
110
  def name(self):
111
+ """The name of the tensor."""
101
112
  if self._name is None:
102
113
  name_ptr = lib.ztensor_metadata_get_name(self._ptr)
103
114
  _check_ptr(name_ptr, "get_name")
@@ -107,6 +118,7 @@ class TensorMetadata:
107
118
 
108
119
  @property
109
120
  def dtype_str(self):
121
+ """The zTensor dtype string (e.g., 'float32')."""
110
122
  if self._dtype_str is None:
111
123
  dtype_ptr = lib.ztensor_metadata_get_dtype_str(self._ptr)
112
124
  _check_ptr(dtype_ptr, "get_dtype_str")
@@ -116,7 +128,7 @@ class TensorMetadata:
116
128
 
117
129
  @property
118
130
  def dtype(self):
119
- """Returns the numpy dtype for this tensor."""
131
+ """The numpy dtype for this tensor."""
120
132
  dtype_str = self.dtype_str
121
133
  dt = DTYPE_ZT_TO_NP.get(dtype_str)
122
134
  if dt is None:
@@ -130,6 +142,7 @@ class TensorMetadata:
130
142
 
131
143
  @property
132
144
  def shape(self):
145
+ """The shape of the tensor as a tuple."""
133
146
  if self._shape is None:
134
147
  shape_len = lib.ztensor_metadata_get_shape_len(self._ptr)
135
148
  if shape_len > 0:
@@ -141,6 +154,64 @@ class TensorMetadata:
141
154
  self._shape = tuple()
142
155
  return self._shape
143
156
 
157
+ @property
158
+ def offset(self):
159
+ """The on-disk offset of the tensor data in bytes."""
160
+ if self._offset is None:
161
+ self._offset = lib.ztensor_metadata_get_offset(self._ptr)
162
+ return self._offset
163
+
164
+ @property
165
+ def size(self):
166
+ """The on-disk size of the tensor data in bytes (can be compressed size)."""
167
+ if self._size is None:
168
+ self._size = lib.ztensor_metadata_get_size(self._ptr)
169
+ return self._size
170
+
171
+ @property
172
+ def layout(self):
173
+ """The tensor layout as a string (e.g., 'dense')."""
174
+ if self._layout is None:
175
+ layout_ptr = lib.ztensor_metadata_get_layout_str(self._ptr)
176
+ _check_ptr(layout_ptr, "get_layout_str")
177
+ self._layout = ffi.string(layout_ptr).decode('utf-8')
178
+ lib.ztensor_free_string(layout_ptr)
179
+ return self._layout
180
+
181
+ @property
182
+ def encoding(self):
183
+ """The tensor encoding as a string (e.g., 'raw', 'zstd')."""
184
+ if self._encoding is None:
185
+ encoding_ptr = lib.ztensor_metadata_get_encoding_str(self._ptr)
186
+ _check_ptr(encoding_ptr, "get_encoding_str")
187
+ self._encoding = ffi.string(encoding_ptr).decode('utf-8')
188
+ lib.ztensor_free_string(encoding_ptr)
189
+ return self._encoding
190
+
191
+ @property
192
+ def endianness(self):
193
+ """The data endianness ('little', 'big') if applicable, else None."""
194
+ if self._endianness == "not_checked":
195
+ endian_ptr = lib.ztensor_metadata_get_data_endianness_str(self._ptr)
196
+ if endian_ptr == ffi.NULL:
197
+ self._endianness = None
198
+ else:
199
+ self._endianness = ffi.string(endian_ptr).decode('utf-8')
200
+ lib.ztensor_free_string(endian_ptr)
201
+ return self._endianness
202
+
203
+ @property
204
+ def checksum(self):
205
+ """The checksum string if present, else None."""
206
+ if self._checksum == "not_checked":
207
+ checksum_ptr = lib.ztensor_metadata_get_checksum_str(self._ptr)
208
+ if checksum_ptr == ffi.NULL:
209
+ self._checksum = None
210
+ else:
211
+ self._checksum = ffi.string(checksum_ptr).decode('utf-8')
212
+ lib.ztensor_free_string(checksum_ptr)
213
+ return self._checksum
214
+
144
215
 
145
216
  class Reader:
146
217
  """A Pythonic context manager for reading zTensor files."""
@@ -157,6 +228,39 @@ class Reader:
157
228
  def __exit__(self, exc_type, exc_val, exc_tb):
158
229
  self._ptr = None
159
230
 
231
+ def __len__(self):
232
+ """Returns the number of tensors in the file."""
233
+ if self._ptr is None: raise ZTensorError("Reader is closed.")
234
+ return lib.ztensor_reader_get_metadata_count(self._ptr)
235
+
236
+ def __iter__(self):
237
+ """Iterates over the metadata of all tensors in the file."""
238
+ if self._ptr is None: raise ZTensorError("Reader is closed.")
239
+ for i in range(len(self)):
240
+ yield self[i]
241
+
242
+ def __getitem__(self, index: int) -> TensorMetadata:
243
+ """Retrieves metadata for a tensor by its index."""
244
+ if self._ptr is None: raise ZTensorError("Reader is closed.")
245
+ if index >= len(self):
246
+ raise IndexError("Tensor index out of range")
247
+ meta_ptr = lib.ztensor_reader_get_metadata_by_index(self._ptr, index)
248
+ _check_ptr(meta_ptr, f"get_metadata_by_index: {index}")
249
+ return TensorMetadata(meta_ptr)
250
+
251
+ def list_tensors(self) -> list[TensorMetadata]:
252
+ """Returns a list of all TensorMetadata objects in the file."""
253
+ return list(self)
254
+
255
+ def get_tensor_names(self) -> list[str]:
256
+ """Returns a list of all tensor names in the file."""
257
+ if self._ptr is None: raise ZTensorError("Reader is closed.")
258
+ c_array_ptr = lib.ztensor_reader_get_all_tensor_names(self._ptr)
259
+ _check_ptr(c_array_ptr, "get_all_tensor_names")
260
+ c_array_ptr = ffi.gc(c_array_ptr, lib.ztensor_free_string_array)
261
+
262
+ return [ffi.string(c_array_ptr.strings[i]).decode('utf-8') for i in range(c_array_ptr.len)]
263
+
160
264
  def get_metadata(self, name: str) -> TensorMetadata:
161
265
  """Retrieves metadata for a tensor by its name."""
162
266
  if self._ptr is None: raise ZTensorError("Reader is closed.")
@@ -177,6 +281,7 @@ class Reader:
177
281
  Returns:
178
282
  np.ndarray or torch.Tensor: The tensor data.
179
283
  """
284
+ if self._ptr is None: raise ZTensorError("Reader is closed.")
180
285
  if to not in ['numpy', 'torch']:
181
286
  raise ValueError(f"Unsupported format: '{to}'. Choose 'numpy' or 'torch'.")
182
287
 
@@ -238,6 +343,7 @@ class Writer:
238
343
  if exc_type is None:
239
344
  self.finalize()
240
345
  else:
346
+ # If an error occurred, just free the handle without finalizing
241
347
  lib.ztensor_writer_free(self._ptr)
242
348
  self._ptr = None
243
349
 
@@ -294,7 +400,7 @@ class Writer:
294
400
  """Finalizes the zTensor file, writing the metadata index."""
295
401
  if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
296
402
  status = lib.ztensor_writer_finalize(self._ptr)
297
- self._ptr = None
403
+ self._ptr = None # The writer is consumed in Rust
298
404
  self._finalized = True
299
405
  _check_status(status, "finalize")
300
406
 
@@ -12,7 +12,35 @@ use crate::models::{ChecksumAlgorithm, DType, DataEndianness, Encoding, Layout,
12
12
  use crate::reader::ZTensorReader;
13
13
  use crate::writer::ZTensorWriter;
14
14
 
15
+ // --- C-Compatible Structs & Handles ---
16
+
17
+ /// Opaque handle to a file-based zTensor reader.
18
+ pub type CZTensorReader = ZTensorReader<BufReader<std::fs::File>>;
19
+ /// Opaque handle to a file-based zTensor writer.
20
+ pub type CZTensorWriter = ZTensorWriter<BufWriter<std::fs::File>>;
21
+ /// Opaque handle to an in-memory zTensor writer.
22
+ pub type CInMemoryZTensorWriter = ZTensorWriter<Cursor<Vec<u8>>>;
23
+ /// Opaque, owned handle to a tensor's metadata.
24
+ pub type CTensorMetadata = TensorMetadata;
25
+
26
+ /// A self-contained, C-compatible view of owned tensor data.
27
+ #[repr(C)]
28
+ pub struct CTensorDataView {
29
+ pub data: *const c_uchar,
30
+ pub len: size_t,
31
+ // Private field holding the owned Vec<u8> data.
32
+ _owner: *mut c_void,
33
+ }
34
+
35
+ /// A C-compatible, heap-allocated array of C strings.
36
+ #[repr(C)]
37
+ pub struct CStringArray {
38
+ pub strings: *mut *mut c_char,
39
+ pub len: size_t,
40
+ }
41
+
15
42
  // --- Error Handling ---
43
+
16
44
  lazy_static! {
17
45
  static ref LAST_ERROR: Mutex<Option<CString>> = Mutex::new(None);
18
46
  }
@@ -23,6 +51,10 @@ fn update_last_error(err: ZTensorError) {
23
51
  *LAST_ERROR.lock().unwrap() = Some(msg);
24
52
  }
25
53
 
54
+ /// Retrieves the last error message set by a failed API call.
55
+ ///
56
+ /// The returned string is valid until the next API call.
57
+ /// Returns `null` if no error has occurred.
26
58
  #[unsafe(no_mangle)]
27
59
  pub extern "C" fn ztensor_last_error_message() -> *const c_char {
28
60
  match LAST_ERROR.lock().unwrap().as_ref() {
@@ -31,30 +63,51 @@ pub extern "C" fn ztensor_last_error_message() -> *const c_char {
31
63
  }
32
64
  }
33
65
 
34
- // --- Opaque Structs and Handles ---
35
-
36
- // Reader is now generic over R: Read + Seek.
37
- // For the C API, we'll expose a version that works with files.
38
- pub type CZTensorReader = ZTensorReader<BufReader<std::fs::File>>;
39
-
40
- // Writer is also generic. We'll expose a file-based and in-memory version.
41
- pub type CZTensorWriter = ZTensorWriter<BufWriter<std::fs::File>>;
42
- pub type CInMemoryZTensorWriter = ZTensorWriter<Cursor<Vec<u8>>>;
66
+ // --- Internal Helpers ---
43
67
 
44
- // Represents an owned, C-compatible TensorMetadata.
45
- pub type CTensorMetadata = TensorMetadata;
68
+ /// A macro to safely access the Rust object behind an opaque C pointer.
69
+ /// It checks for null and returns a safe reference, handling the error case.
70
+ macro_rules! ztensor_handle {
71
+ ($ptr:expr) => {
72
+ if $ptr.is_null() {
73
+ update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
74
+ return ptr::null_mut();
75
+ } else {
76
+ unsafe { &*$ptr }
77
+ }
78
+ };
79
+ (mut $ptr:expr) => {
80
+ if $ptr.is_null() {
81
+ update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
82
+ return ptr::null_mut();
83
+ } else {
84
+ unsafe { &mut *$ptr }
85
+ }
86
+ };
87
+ ($ptr:expr, $err_ret:expr) => {
88
+ if $ptr.is_null() {
89
+ update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
90
+ return $err_ret;
91
+ } else {
92
+ unsafe { &*$ptr }
93
+ }
94
+ };
95
+ (mut $ptr:expr, $err_ret:expr) => {
96
+ if $ptr.is_null() {
97
+ update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
98
+ return $err_ret;
99
+ } else {
100
+ unsafe { &mut *$ptr }
101
+ }
102
+ };
103
+ }
46
104
 
47
- // Represents a non-owning, zero-copy view of tensor data.
48
- #[repr(C)]
49
- pub struct CTensorDataView {
50
- pub data: *const c_uchar,
51
- pub len: size_t,
52
- // Private field to hold the owned data, making this struct self-contained.
53
- // The C side only sees a pointer to the data, but this struct owns the Vec.
54
- _owner: *mut c_void,
105
+ /// Helper to convert a Rust String into a C-style, null-terminated string pointer.
106
+ fn to_cstring(s: String) -> *mut c_char {
107
+ CString::new(s).map_or(ptr::null_mut(), |cs| cs.into_raw())
55
108
  }
56
109
 
57
- // --- Reader Functions ---
110
+ // --- Reader API ---
58
111
 
59
112
  #[unsafe(no_mangle)]
60
113
  pub extern "C" fn ztensor_reader_open(path_str: *const c_char) -> *mut CZTensorReader {
@@ -62,11 +115,8 @@ pub extern "C" fn ztensor_reader_open(path_str: *const c_char) -> *mut CZTensorR
62
115
  update_last_error(ZTensorError::Other("Null path provided".into()));
63
116
  return ptr::null_mut();
64
117
  }
65
- let path = unsafe { CStr::from_ptr(path_str) };
66
- let path_res = path.to_str().map(Path::new);
67
-
68
- let path = match path_res {
69
- Ok(p) => p,
118
+ let path = match unsafe { CStr::from_ptr(path_str).to_str() } {
119
+ Ok(s) => Path::new(s),
70
120
  Err(_) => {
71
121
  update_last_error(ZTensorError::Other("Invalid UTF-8 path".into()));
72
122
  return ptr::null_mut();
@@ -82,22 +132,9 @@ pub extern "C" fn ztensor_reader_open(path_str: *const c_char) -> *mut CZTensorR
82
132
  }
83
133
  }
84
134
 
85
- #[unsafe(no_mangle)]
86
- pub extern "C" fn ztensor_reader_free(reader_ptr: *mut CZTensorReader) {
87
- if !reader_ptr.is_null() {
88
- unsafe {
89
- let _ = Box::from_raw(reader_ptr);
90
- };
91
- }
92
- }
93
-
94
135
  #[unsafe(no_mangle)]
95
136
  pub extern "C" fn ztensor_reader_get_metadata_count(reader_ptr: *const CZTensorReader) -> size_t {
96
- let reader = unsafe {
97
- reader_ptr
98
- .as_ref()
99
- .expect("Null pointer passed to ztensor_reader_get_metadata_count")
100
- };
137
+ let reader = ztensor_handle!(reader_ptr, 0);
101
138
  reader.list_tensors().len()
102
139
  }
103
140
 
@@ -106,11 +143,7 @@ pub extern "C" fn ztensor_reader_get_metadata_by_name(
106
143
  reader_ptr: *const CZTensorReader,
107
144
  name_str: *const c_char,
108
145
  ) -> *mut CTensorMetadata {
109
- let reader = unsafe {
110
- reader_ptr
111
- .as_ref()
112
- .expect("Null pointer passed to ztensor_reader_get_metadata_by_name")
113
- };
146
+ let reader = ztensor_handle!(reader_ptr);
114
147
  if name_str.is_null() {
115
148
  update_last_error(ZTensorError::Other("Null name pointer provided".into()));
116
149
  return ptr::null_mut();
@@ -126,24 +159,53 @@ pub extern "C" fn ztensor_reader_get_metadata_by_name(
126
159
  }
127
160
  }
128
161
 
129
- // --- Zero-Copy Tensor Data Reading ---
162
+ #[unsafe(no_mangle)]
163
+ pub extern "C" fn ztensor_reader_get_metadata_by_index(
164
+ reader_ptr: *const CZTensorReader,
165
+ index: size_t,
166
+ ) -> *mut CTensorMetadata {
167
+ let reader = ztensor_handle!(reader_ptr);
168
+ match reader.list_tensors().get(index) {
169
+ Some(metadata) => Box::into_raw(Box::new(metadata.clone())),
170
+ None => {
171
+ update_last_error(ZTensorError::Other(format!(
172
+ "Index {} is out of bounds for tensor list of length {}",
173
+ index,
174
+ reader.list_tensors().len()
175
+ )));
176
+ ptr::null_mut()
177
+ }
178
+ }
179
+ }
180
+
181
+ #[unsafe(no_mangle)]
182
+ pub extern "C" fn ztensor_reader_get_all_tensor_names(
183
+ reader_ptr: *const CZTensorReader,
184
+ ) -> *mut CStringArray {
185
+ let reader = ztensor_handle!(reader_ptr);
186
+ let names: Vec<CString> = reader
187
+ .list_tensors()
188
+ .iter()
189
+ .map(|m| CString::new(m.name.as_str()).unwrap())
190
+ .collect();
191
+
192
+ let mut c_names: Vec<*mut c_char> = names.into_iter().map(|s| s.into_raw()).collect();
193
+ let string_array = Box::new(CStringArray {
194
+ strings: c_names.as_mut_ptr(),
195
+ len: c_names.len(),
196
+ });
197
+
198
+ std::mem::forget(c_names); // C side is now responsible for this memory
199
+ Box::into_raw(string_array)
200
+ }
130
201
 
131
- /// Reads tensor data into a view without copying. The view must be freed.
132
202
  #[unsafe(no_mangle)]
133
203
  pub extern "C" fn ztensor_reader_read_tensor_view(
134
204
  reader_ptr: *mut CZTensorReader,
135
205
  metadata_ptr: *const CTensorMetadata,
136
206
  ) -> *mut CTensorDataView {
137
- let reader = unsafe {
138
- reader_ptr
139
- .as_mut()
140
- .expect("Null reader pointer to read_tensor_view")
141
- };
142
- let metadata = unsafe {
143
- metadata_ptr
144
- .as_ref()
145
- .expect("Null metadata pointer to read_tensor_view")
146
- };
207
+ let reader = ztensor_handle!(mut reader_ptr);
208
+ let metadata = ztensor_handle!(metadata_ptr);
147
209
 
148
210
  match reader.read_raw_tensor_data(metadata) {
149
211
  Ok(data_vec) => {
@@ -161,84 +223,7 @@ pub extern "C" fn ztensor_reader_read_tensor_view(
161
223
  }
162
224
  }
163
225
 
164
- /// Frees the tensor data view.
165
- #[unsafe(no_mangle)]
166
- pub extern "C" fn ztensor_free_tensor_view(view_ptr: *mut CTensorDataView) {
167
- if !view_ptr.is_null() {
168
- unsafe {
169
- let view = Box::from_raw(view_ptr);
170
- // This will drop the Vec<u8> that `_owner` points to
171
- let _ = Box::from_raw(view._owner as *mut Vec<u8>);
172
- }
173
- }
174
- }
175
-
176
- // --- Metadata Accessors ---
177
- // (These are largely the same but ensured to be safe)
178
-
179
- #[unsafe(no_mangle)]
180
- pub extern "C" fn ztensor_metadata_free(metadata_ptr: *mut CTensorMetadata) {
181
- if !metadata_ptr.is_null() {
182
- unsafe {
183
- let _ = Box::from_raw(metadata_ptr);
184
- };
185
- }
186
- }
187
-
188
- fn to_cstring(s: String) -> *mut c_char {
189
- CString::new(s).map_or(ptr::null_mut(), |cs| cs.into_raw())
190
- }
191
-
192
- #[unsafe(no_mangle)]
193
- pub extern "C" fn ztensor_metadata_get_name(metadata_ptr: *const CTensorMetadata) -> *mut c_char {
194
- let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
195
- to_cstring(metadata.name.clone())
196
- }
197
-
198
- #[unsafe(no_mangle)]
199
- pub extern "C" fn ztensor_metadata_get_dtype_str(
200
- metadata_ptr: *const CTensorMetadata,
201
- ) -> *mut c_char {
202
- let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
203
- to_cstring(metadata.dtype.to_string_key())
204
- }
205
-
206
- // ... other metadata accessors like get_offset, get_size, etc.
207
-
208
- /// Returns the number of dimensions in the tensor's shape.
209
- #[unsafe(no_mangle)]
210
- pub extern "C" fn ztensor_metadata_get_shape_len(metadata_ptr: *const CTensorMetadata) -> size_t {
211
- let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
212
- metadata.shape.len()
213
- }
214
-
215
- /// Returns a pointer to the shape data (an array of u64).
216
- /// The caller owns this memory and must free it with `ztensor_free_u64_array`.
217
- #[unsafe(no_mangle)]
218
- pub extern "C" fn ztensor_metadata_get_shape_data(
219
- metadata_ptr: *const CTensorMetadata,
220
- ) -> *mut u64 {
221
- let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
222
- // Clone the shape into a new Vec, get its raw pointer, and forget it
223
- // so Rust doesn't deallocate it. The C side is now responsible.
224
- let mut shape_vec = metadata.shape.clone();
225
- let ptr = shape_vec.as_mut_ptr();
226
- std::mem::forget(shape_vec);
227
- ptr
228
- }
229
-
230
- /// Frees the shape array allocated by `ztensor_metadata_get_shape_data`.
231
- #[unsafe(no_mangle)]
232
- pub extern "C" fn ztensor_free_u64_array(ptr: *mut u64, len: size_t) {
233
- if !ptr.is_null() {
234
- unsafe {
235
- // Reconstitute the Vec from the raw parts and let it drop, freeing the memory.
236
- let _ = Vec::from_raw_parts(ptr, len, len);
237
- }
238
- }
239
- }
240
-
241
- // --- Writer Functions ---
226
+ // --- Writer API ---
242
227
 
243
228
  #[unsafe(no_mangle)]
244
229
  pub extern "C" fn ztensor_writer_create(path_str: *const c_char) -> *mut CZTensorWriter {
@@ -246,9 +231,8 @@ pub extern "C" fn ztensor_writer_create(path_str: *const c_char) -> *mut CZTenso
246
231
  update_last_error(ZTensorError::Other("Null path provided".into()));
247
232
  return ptr::null_mut();
248
233
  }
249
- let path = unsafe { CStr::from_ptr(path_str) };
250
- let path = match path.to_str().map(Path::new) {
251
- Ok(p) => p,
234
+ let path = match unsafe { CStr::from_ptr(path_str).to_str() } {
235
+ Ok(s) => Path::new(s),
252
236
  Err(_) => {
253
237
  update_last_error(ZTensorError::Other("Invalid UTF-8 path".into()));
254
238
  return ptr::null_mut();
@@ -264,14 +248,6 @@ pub extern "C" fn ztensor_writer_create(path_str: *const c_char) -> *mut CZTenso
264
248
  }
265
249
  }
266
250
 
267
- #[unsafe(no_mangle)]
268
- pub extern "C" fn ztensor_writer_free(writer_ptr: *mut CZTensorWriter) {
269
- if !writer_ptr.is_null() {
270
- let _ = unsafe { Box::from_raw(writer_ptr) };
271
- // Dropping the writer will close the file.
272
- }
273
- }
274
-
275
251
  #[unsafe(no_mangle)]
276
252
  pub extern "C" fn ztensor_writer_add_tensor(
277
253
  writer_ptr: *mut CZTensorWriter,
@@ -282,16 +258,26 @@ pub extern "C" fn ztensor_writer_add_tensor(
282
258
  data_ptr: *const c_uchar,
283
259
  data_len: size_t,
284
260
  ) -> c_int {
285
- let writer = unsafe { writer_ptr.as_mut().expect("Null writer pointer") };
261
+ let writer = ztensor_handle!(mut writer_ptr, -1);
286
262
  let name = unsafe { CStr::from_ptr(name_str).to_str().unwrap() };
287
263
  let shape = unsafe { slice::from_raw_parts(shape_ptr, shape_len) };
288
264
  let dtype_str = unsafe { CStr::from_ptr(dtype_str).to_str().unwrap() };
289
265
  let data = unsafe { slice::from_raw_parts(data_ptr, data_len) };
290
266
 
291
267
  let dtype = match dtype_str {
268
+ "float64" => DType::Float64,
292
269
  "float32" => DType::Float32,
270
+ "float16" => DType::Float16,
271
+ "bfloat16" => DType::BFloat16,
272
+ "int64" => DType::Int64,
273
+ "int32" => DType::Int32,
274
+ "int16" => DType::Int16,
275
+ "int8" => DType::Int8,
276
+ "uint64" => DType::Uint64,
277
+ "uint32" => DType::Uint32,
278
+ "uint16" => DType::Uint16,
293
279
  "uint8" => DType::Uint8,
294
- // ... add other dtypes
280
+ "bool" => DType::Bool,
295
281
  _ => {
296
282
  update_last_error(ZTensorError::UnsupportedDType(dtype_str.to_string()));
297
283
  return -1;
@@ -305,7 +291,7 @@ pub extern "C" fn ztensor_writer_add_tensor(
305
291
  Layout::Dense,
306
292
  Encoding::Raw,
307
293
  data.to_vec(),
308
- Some(DataEndianness::Little), // Defaulting to little, could be a parameter
294
+ Some(DataEndianness::Little),
309
295
  ChecksumAlgorithm::None,
310
296
  None,
311
297
  );
@@ -324,8 +310,6 @@ pub extern "C" fn ztensor_writer_finalize(writer_ptr: *mut CZTensorWriter) -> c_
324
310
  if writer_ptr.is_null() {
325
311
  return -1;
326
312
  }
327
- // `from_raw` takes ownership and will drop the writer when it goes out of scope.
328
- // The writer's `drop` implementation should handle finalization.
329
313
  let writer = unsafe { Box::from_raw(writer_ptr) };
330
314
  match writer.finalize() {
331
315
  Ok(_) => 0,
@@ -336,13 +320,147 @@ pub extern "C" fn ztensor_writer_finalize(writer_ptr: *mut CZTensorWriter) -> c_
336
320
  }
337
321
  }
338
322
 
339
- // --- Memory Freeing for C ---
323
+ // --- Metadata API ---
324
+
325
+ #[unsafe(no_mangle)]
326
+ pub extern "C" fn ztensor_metadata_get_name(metadata_ptr: *const CTensorMetadata) -> *mut c_char {
327
+ let metadata = ztensor_handle!(metadata_ptr);
328
+ to_cstring(metadata.name.clone())
329
+ }
330
+
331
+ #[unsafe(no_mangle)]
332
+ pub extern "C" fn ztensor_metadata_get_dtype_str(
333
+ metadata_ptr: *const CTensorMetadata,
334
+ ) -> *mut c_char {
335
+ let metadata = ztensor_handle!(metadata_ptr);
336
+ to_cstring(metadata.dtype.to_string_key())
337
+ }
338
+
339
+ #[unsafe(no_mangle)]
340
+ pub extern "C" fn ztensor_metadata_get_offset(metadata_ptr: *const CTensorMetadata) -> u64 {
341
+ let metadata = ztensor_handle!(metadata_ptr, 0);
342
+ metadata.offset
343
+ }
344
+
345
+ #[unsafe(no_mangle)]
346
+ pub extern "C" fn ztensor_metadata_get_size(metadata_ptr: *const CTensorMetadata) -> u64 {
347
+ let metadata = ztensor_handle!(metadata_ptr, 0);
348
+ metadata.size
349
+ }
350
+
351
+ #[unsafe(no_mangle)]
352
+ pub extern "C" fn ztensor_metadata_get_layout_str(
353
+ metadata_ptr: *const CTensorMetadata,
354
+ ) -> *mut c_char {
355
+ let metadata = ztensor_handle!(metadata_ptr);
356
+ to_cstring(format!("{:?}", metadata.layout).to_lowercase())
357
+ }
358
+
359
+ #[unsafe(no_mangle)]
360
+ pub extern "C" fn ztensor_metadata_get_encoding_str(
361
+ metadata_ptr: *const CTensorMetadata,
362
+ ) -> *mut c_char {
363
+ let metadata = ztensor_handle!(metadata_ptr);
364
+ to_cstring(format!("{:?}", metadata.encoding).to_lowercase())
365
+ }
366
+
367
+ #[unsafe(no_mangle)]
368
+ pub extern "C" fn ztensor_metadata_get_data_endianness_str(
369
+ metadata_ptr: *const CTensorMetadata,
370
+ ) -> *mut c_char {
371
+ let metadata = ztensor_handle!(metadata_ptr);
372
+ match &metadata.data_endianness {
373
+ Some(endianness) => to_cstring(format!("{:?}", endianness).to_lowercase()),
374
+ None => ptr::null_mut(),
375
+ }
376
+ }
377
+
378
+ #[unsafe(no_mangle)]
379
+ pub extern "C" fn ztensor_metadata_get_checksum_str(
380
+ metadata_ptr: *const CTensorMetadata,
381
+ ) -> *mut c_char {
382
+ let metadata = ztensor_handle!(metadata_ptr);
383
+ match &metadata.checksum {
384
+ Some(s) => to_cstring(s.clone()),
385
+ None => ptr::null_mut(),
386
+ }
387
+ }
388
+
389
+ #[unsafe(no_mangle)]
390
+ pub extern "C" fn ztensor_metadata_get_shape_len(metadata_ptr: *const CTensorMetadata) -> size_t {
391
+ let metadata = ztensor_handle!(metadata_ptr, 0);
392
+ metadata.shape.len()
393
+ }
394
+
395
+ #[unsafe(no_mangle)]
396
+ pub extern "C" fn ztensor_metadata_get_shape_data(
397
+ metadata_ptr: *const CTensorMetadata,
398
+ ) -> *mut u64 {
399
+ let metadata = ztensor_handle!(metadata_ptr);
400
+ let mut shape_vec = metadata.shape.clone();
401
+ let ptr = shape_vec.as_mut_ptr();
402
+ std::mem::forget(shape_vec);
403
+ ptr
404
+ }
405
+
406
+ // --- Memory Management API ---
407
+
408
+ #[unsafe(no_mangle)]
409
+ pub extern "C" fn ztensor_reader_free(reader_ptr: *mut CZTensorReader) {
410
+ if !reader_ptr.is_null() {
411
+ let _ = unsafe { Box::from_raw(reader_ptr) };
412
+ }
413
+ }
414
+
415
+ #[unsafe(no_mangle)]
416
+ pub extern "C" fn ztensor_writer_free(writer_ptr: *mut CZTensorWriter) {
417
+ if !writer_ptr.is_null() {
418
+ let _ = unsafe { Box::from_raw(writer_ptr) };
419
+ }
420
+ }
421
+
422
+ #[unsafe(no_mangle)]
423
+ pub extern "C" fn ztensor_metadata_free(metadata_ptr: *mut CTensorMetadata) {
424
+ if !metadata_ptr.is_null() {
425
+ let _ = unsafe { Box::from_raw(metadata_ptr) };
426
+ }
427
+ }
428
+
429
+ #[unsafe(no_mangle)]
430
+ pub extern "C" fn ztensor_free_tensor_view(view_ptr: *mut CTensorDataView) {
431
+ if !view_ptr.is_null() {
432
+ unsafe {
433
+ let view = Box::from_raw(view_ptr);
434
+ // This reconstitutes the `Vec<u8>` and allows it to be dropped.
435
+ let _ = Box::from_raw(view._owner as *mut Vec<u8>);
436
+ }
437
+ }
438
+ }
340
439
 
341
440
  #[unsafe(no_mangle)]
342
441
  pub extern "C" fn ztensor_free_string(s: *mut c_char) {
343
442
  if !s.is_null() {
344
- unsafe {
345
- let _ = CString::from_raw(s);
346
- };
443
+ let _ = unsafe { CString::from_raw(s) };
444
+ }
445
+ }
446
+
447
+ #[unsafe(no_mangle)]
448
+ pub extern "C" fn ztensor_free_string_array(arr_ptr: *mut CStringArray) {
449
+ if arr_ptr.is_null() {
450
+ return;
451
+ }
452
+ unsafe {
453
+ let arr = Box::from_raw(arr_ptr);
454
+ let strings = Vec::from_raw_parts(arr.strings, arr.len, arr.len);
455
+ for s_ptr in strings {
456
+ let _ = CString::from_raw(s_ptr);
457
+ }
458
+ }
459
+ }
460
+
461
+ #[unsafe(no_mangle)]
462
+ pub extern "C" fn ztensor_free_u64_array(ptr: *mut u64, len: size_t) {
463
+ if !ptr.is_null() {
464
+ let _ = unsafe { Vec::from_raw_parts(ptr, len, len) };
347
465
  }
348
466
  }
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes