ztensor 0.1.2__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ztensor might be problematic. Click here for more details.
- {ztensor-0.1.2 → ztensor-0.1.4}/PKG-INFO +1 -1
- {ztensor-0.1.2 → ztensor-0.1.4}/pyproject.toml +1 -1
- ztensor-0.1.4/python/examples/test.py +171 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/python/ztensor/__init__.py +108 -2
- {ztensor-0.1.2 → ztensor-0.1.4}/src/ffi.rs +272 -154
- {ztensor-0.1.2 → ztensor-0.1.4}/.github/workflows/CI.yml +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/.gitignore +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/Cargo.lock +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/Cargo.toml +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/LICENSE +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/README.md +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/python/LICENSE +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/python/README.md +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/python/examples/basic.py +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/src/error.rs +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/src/lib.rs +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/src/models.rs +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/src/reader.rs +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/src/utils.rs +0 -0
- {ztensor-0.1.2 → ztensor-0.1.4}/src/writer.rs +0 -0
|
@@ -10,7 +10,7 @@ build-backend = "maturin"
|
|
|
10
10
|
[project]
|
|
11
11
|
# The name of your package on PyPI.
|
|
12
12
|
name = "ztensor"
|
|
13
|
-
version = "0.1.
|
|
13
|
+
version = "0.1.4"
|
|
14
14
|
description = "Python bindings for the zTensor library."
|
|
15
15
|
readme = "README.md" # It's good practice to have a README.
|
|
16
16
|
authors = [
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
import os
|
|
3
|
+
import shutil
|
|
4
|
+
import tempfile
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
# --- Conditional PyTorch Import ---
|
|
8
|
+
try:
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
TORCH_AVAILABLE = True
|
|
12
|
+
except ImportError:
|
|
13
|
+
TORCH_AVAILABLE = False
|
|
14
|
+
|
|
15
|
+
# --- Import the ztensor wrapper ---
|
|
16
|
+
# Assuming the bindings file is named 'bindings.py' in a 'ztensor' package/directory
|
|
17
|
+
from ztensor import Reader, Writer, ZTensorError, TensorMetadata
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestZTensorBindings(unittest.TestCase):
|
|
21
|
+
|
|
22
|
+
def setUp(self):
|
|
23
|
+
"""Set up a temporary directory for test files."""
|
|
24
|
+
self.test_dir = tempfile.mkdtemp()
|
|
25
|
+
self.test_file = os.path.join(self.test_dir, "test.zt")
|
|
26
|
+
|
|
27
|
+
def tearDown(self):
|
|
28
|
+
"""Remove the temporary directory and its contents."""
|
|
29
|
+
shutil.rmtree(self.test_dir)
|
|
30
|
+
|
|
31
|
+
def test_01_writer_and_reader_numpy(self):
|
|
32
|
+
"""Test writing and reading a single NumPy tensor."""
|
|
33
|
+
tensor_a = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
|
|
34
|
+
tensor_b = np.array([[True, False], [False, True]], dtype=bool)
|
|
35
|
+
|
|
36
|
+
with Writer(self.test_file) as writer:
|
|
37
|
+
writer.add_tensor("tensor_a", tensor_a)
|
|
38
|
+
writer.add_tensor("tensor_b", tensor_b)
|
|
39
|
+
|
|
40
|
+
self.assertTrue(os.path.exists(self.test_file))
|
|
41
|
+
|
|
42
|
+
with Reader(self.test_file) as reader:
|
|
43
|
+
read_a = reader.read_tensor("tensor_a")
|
|
44
|
+
read_b = reader.read_tensor("tensor_b")
|
|
45
|
+
|
|
46
|
+
self.assertIsInstance(read_a, np.ndarray)
|
|
47
|
+
self.assertTrue(np.array_equal(tensor_a, read_a))
|
|
48
|
+
self.assertEqual(tensor_a.dtype, read_a.dtype)
|
|
49
|
+
self.assertEqual(tensor_a.shape, read_a.shape)
|
|
50
|
+
|
|
51
|
+
self.assertIsInstance(read_b, np.ndarray)
|
|
52
|
+
self.assertTrue(np.array_equal(tensor_b, read_b))
|
|
53
|
+
|
|
54
|
+
@unittest.skipIf(not TORCH_AVAILABLE, "PyTorch not installed")
|
|
55
|
+
def test_02_writer_and_reader_torch(self):
|
|
56
|
+
"""Test writing and reading a PyTorch tensor."""
|
|
57
|
+
tensor_a = torch.randn(10, 20, dtype=torch.float16)
|
|
58
|
+
tensor_b = torch.randint(0, 255, (100,), dtype=torch.uint8)
|
|
59
|
+
|
|
60
|
+
with Writer(self.test_file) as writer:
|
|
61
|
+
writer.add_tensor("torch_a", tensor_a)
|
|
62
|
+
writer.add_tensor("torch_b", tensor_b)
|
|
63
|
+
|
|
64
|
+
with Reader(self.test_file) as reader:
|
|
65
|
+
# Read back as torch tensor
|
|
66
|
+
read_a_torch = reader.read_tensor("torch_a", to='torch')
|
|
67
|
+
self.assertIsInstance(read_a_torch, torch.Tensor)
|
|
68
|
+
self.assertTrue(torch.equal(tensor_a, read_a_torch))
|
|
69
|
+
self.assertEqual(tensor_a.dtype, read_a_torch.dtype)
|
|
70
|
+
|
|
71
|
+
# Read back as numpy array
|
|
72
|
+
read_b_np = reader.read_tensor("torch_b", to='numpy')
|
|
73
|
+
self.assertIsInstance(read_b_np, np.ndarray)
|
|
74
|
+
self.assertTrue(np.array_equal(tensor_b.numpy(), read_b_np))
|
|
75
|
+
|
|
76
|
+
def test_03_metadata_access_and_iteration(self):
|
|
77
|
+
"""Test the reader's container and metadata features."""
|
|
78
|
+
tensors = {
|
|
79
|
+
"tensor_int": np.ones((5, 5), dtype=np.int32),
|
|
80
|
+
"tensor_float": np.zeros((10,), dtype=np.float64),
|
|
81
|
+
"scalar": np.array(3.14, dtype=np.float32)
|
|
82
|
+
}
|
|
83
|
+
tensor_names = sorted(tensors.keys())
|
|
84
|
+
|
|
85
|
+
with Writer(self.test_file) as writer:
|
|
86
|
+
for name, tensor in tensors.items():
|
|
87
|
+
writer.add_tensor(name, tensor)
|
|
88
|
+
|
|
89
|
+
with Reader(self.test_file) as reader:
|
|
90
|
+
# Test __len__
|
|
91
|
+
self.assertEqual(len(reader), 3)
|
|
92
|
+
|
|
93
|
+
# Test get_tensor_names
|
|
94
|
+
read_names = sorted(reader.get_tensor_names())
|
|
95
|
+
self.assertEqual(tensor_names, read_names)
|
|
96
|
+
|
|
97
|
+
# Test __iter__ and __getitem__
|
|
98
|
+
all_meta = []
|
|
99
|
+
for i in range(len(reader)):
|
|
100
|
+
meta = reader[i]
|
|
101
|
+
self.assertIsInstance(meta, TensorMetadata)
|
|
102
|
+
all_meta.append(meta.name)
|
|
103
|
+
self.assertEqual(tensor_names, sorted(all_meta))
|
|
104
|
+
|
|
105
|
+
# Test list_tensors
|
|
106
|
+
self.assertEqual(len(reader.list_tensors()), 3)
|
|
107
|
+
|
|
108
|
+
# Test specific metadata properties
|
|
109
|
+
meta_scalar = reader.get_metadata("scalar")
|
|
110
|
+
self.assertEqual(meta_scalar.name, "scalar")
|
|
111
|
+
self.assertEqual(meta_scalar.shape, (1, ))
|
|
112
|
+
self.assertEqual(meta_scalar.dtype, np.dtype('float32'))
|
|
113
|
+
self.assertEqual(meta_scalar.dtype_str, 'float32')
|
|
114
|
+
self.assertGreater(meta_scalar.offset, 0)
|
|
115
|
+
self.assertGreater(meta_scalar.size, 0)
|
|
116
|
+
self.assertEqual(meta_scalar.layout, "dense")
|
|
117
|
+
self.assertEqual(meta_scalar.encoding, "raw")
|
|
118
|
+
self.assertIn(meta_scalar.endianness, ["little", "big"])
|
|
119
|
+
self.assertIsNone(meta_scalar.checksum)
|
|
120
|
+
|
|
121
|
+
def test_04_error_handling(self):
|
|
122
|
+
"""Test expected failure modes."""
|
|
123
|
+
tensor = np.arange(10)
|
|
124
|
+
with Writer(self.test_file) as writer:
|
|
125
|
+
writer.add_tensor("my_tensor", tensor)
|
|
126
|
+
|
|
127
|
+
# Test reading non-existent tensor
|
|
128
|
+
with Reader(self.test_file) as reader:
|
|
129
|
+
with self.assertRaisesRegex(ZTensorError, "Tensor not found"):
|
|
130
|
+
reader.read_tensor("non_existent_tensor")
|
|
131
|
+
|
|
132
|
+
# Test accessing closed reader
|
|
133
|
+
reader = Reader(self.test_file)
|
|
134
|
+
reader.__exit__(None, None, None) # Manually close
|
|
135
|
+
with self.assertRaisesRegex(ZTensorError, "Reader is closed"):
|
|
136
|
+
len(reader)
|
|
137
|
+
with self.assertRaisesRegex(ZTensorError, "Reader is closed"):
|
|
138
|
+
reader.read_tensor("my_tensor")
|
|
139
|
+
|
|
140
|
+
# Test index out of range
|
|
141
|
+
with Reader(self.test_file) as reader:
|
|
142
|
+
with self.assertRaises(IndexError):
|
|
143
|
+
reader[99]
|
|
144
|
+
|
|
145
|
+
# Test writing to a finalized writer
|
|
146
|
+
writer = Writer(self.test_file)
|
|
147
|
+
writer.add_tensor("t1", tensor)
|
|
148
|
+
writer.finalize()
|
|
149
|
+
with self.assertRaisesRegex(ZTensorError, "Writer is closed or finalized."):
|
|
150
|
+
writer.add_tensor("t2", tensor)
|
|
151
|
+
|
|
152
|
+
def test_05_writer_context_manager_exception(self):
|
|
153
|
+
"""Ensure writer does not finalize on error inside `with` block."""
|
|
154
|
+
bad_file_path = os.path.join(self.test_dir, "bad_file.zt")
|
|
155
|
+
try:
|
|
156
|
+
with Writer(bad_file_path) as writer:
|
|
157
|
+
writer.add_tensor("good_tensor", np.arange(5))
|
|
158
|
+
raise ValueError("Simulating an error during writing")
|
|
159
|
+
except ValueError:
|
|
160
|
+
pass # Expected
|
|
161
|
+
|
|
162
|
+
# The file should exist but be empty/invalid because finalize was not called.
|
|
163
|
+
self.assertTrue(os.path.exists(bad_file_path))
|
|
164
|
+
# Attempting to read should fail
|
|
165
|
+
with self.assertRaises(ZTensorError):
|
|
166
|
+
with Reader(bad_file_path) as reader:
|
|
167
|
+
len(reader)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
if __name__ == '__main__':
|
|
171
|
+
unittest.main(verbosity=2)
|
|
@@ -92,12 +92,23 @@ class TensorMetadata:
|
|
|
92
92
|
def __init__(self, meta_ptr):
|
|
93
93
|
self._ptr = ffi.gc(meta_ptr, lib.ztensor_metadata_free)
|
|
94
94
|
_check_ptr(self._ptr, "TensorMetadata constructor")
|
|
95
|
+
# Cache for properties to avoid repeated FFI calls
|
|
95
96
|
self._name = None
|
|
96
97
|
self._dtype_str = None
|
|
97
98
|
self._shape = None
|
|
99
|
+
self._offset = None
|
|
100
|
+
self._size = None
|
|
101
|
+
self._layout = None
|
|
102
|
+
self._encoding = None
|
|
103
|
+
self._endianness = "not_checked"
|
|
104
|
+
self._checksum = "not_checked"
|
|
105
|
+
|
|
106
|
+
def __repr__(self):
|
|
107
|
+
return f"<TensorMetadata name='{self.name}' shape={self.shape} dtype='{self.dtype_str}'>"
|
|
98
108
|
|
|
99
109
|
@property
|
|
100
110
|
def name(self):
|
|
111
|
+
"""The name of the tensor."""
|
|
101
112
|
if self._name is None:
|
|
102
113
|
name_ptr = lib.ztensor_metadata_get_name(self._ptr)
|
|
103
114
|
_check_ptr(name_ptr, "get_name")
|
|
@@ -107,6 +118,7 @@ class TensorMetadata:
|
|
|
107
118
|
|
|
108
119
|
@property
|
|
109
120
|
def dtype_str(self):
|
|
121
|
+
"""The zTensor dtype string (e.g., 'float32')."""
|
|
110
122
|
if self._dtype_str is None:
|
|
111
123
|
dtype_ptr = lib.ztensor_metadata_get_dtype_str(self._ptr)
|
|
112
124
|
_check_ptr(dtype_ptr, "get_dtype_str")
|
|
@@ -116,7 +128,7 @@ class TensorMetadata:
|
|
|
116
128
|
|
|
117
129
|
@property
|
|
118
130
|
def dtype(self):
|
|
119
|
-
"""
|
|
131
|
+
"""The numpy dtype for this tensor."""
|
|
120
132
|
dtype_str = self.dtype_str
|
|
121
133
|
dt = DTYPE_ZT_TO_NP.get(dtype_str)
|
|
122
134
|
if dt is None:
|
|
@@ -130,6 +142,7 @@ class TensorMetadata:
|
|
|
130
142
|
|
|
131
143
|
@property
|
|
132
144
|
def shape(self):
|
|
145
|
+
"""The shape of the tensor as a tuple."""
|
|
133
146
|
if self._shape is None:
|
|
134
147
|
shape_len = lib.ztensor_metadata_get_shape_len(self._ptr)
|
|
135
148
|
if shape_len > 0:
|
|
@@ -141,6 +154,64 @@ class TensorMetadata:
|
|
|
141
154
|
self._shape = tuple()
|
|
142
155
|
return self._shape
|
|
143
156
|
|
|
157
|
+
@property
|
|
158
|
+
def offset(self):
|
|
159
|
+
"""The on-disk offset of the tensor data in bytes."""
|
|
160
|
+
if self._offset is None:
|
|
161
|
+
self._offset = lib.ztensor_metadata_get_offset(self._ptr)
|
|
162
|
+
return self._offset
|
|
163
|
+
|
|
164
|
+
@property
|
|
165
|
+
def size(self):
|
|
166
|
+
"""The on-disk size of the tensor data in bytes (can be compressed size)."""
|
|
167
|
+
if self._size is None:
|
|
168
|
+
self._size = lib.ztensor_metadata_get_size(self._ptr)
|
|
169
|
+
return self._size
|
|
170
|
+
|
|
171
|
+
@property
|
|
172
|
+
def layout(self):
|
|
173
|
+
"""The tensor layout as a string (e.g., 'dense')."""
|
|
174
|
+
if self._layout is None:
|
|
175
|
+
layout_ptr = lib.ztensor_metadata_get_layout_str(self._ptr)
|
|
176
|
+
_check_ptr(layout_ptr, "get_layout_str")
|
|
177
|
+
self._layout = ffi.string(layout_ptr).decode('utf-8')
|
|
178
|
+
lib.ztensor_free_string(layout_ptr)
|
|
179
|
+
return self._layout
|
|
180
|
+
|
|
181
|
+
@property
|
|
182
|
+
def encoding(self):
|
|
183
|
+
"""The tensor encoding as a string (e.g., 'raw', 'zstd')."""
|
|
184
|
+
if self._encoding is None:
|
|
185
|
+
encoding_ptr = lib.ztensor_metadata_get_encoding_str(self._ptr)
|
|
186
|
+
_check_ptr(encoding_ptr, "get_encoding_str")
|
|
187
|
+
self._encoding = ffi.string(encoding_ptr).decode('utf-8')
|
|
188
|
+
lib.ztensor_free_string(encoding_ptr)
|
|
189
|
+
return self._encoding
|
|
190
|
+
|
|
191
|
+
@property
|
|
192
|
+
def endianness(self):
|
|
193
|
+
"""The data endianness ('little', 'big') if applicable, else None."""
|
|
194
|
+
if self._endianness == "not_checked":
|
|
195
|
+
endian_ptr = lib.ztensor_metadata_get_data_endianness_str(self._ptr)
|
|
196
|
+
if endian_ptr == ffi.NULL:
|
|
197
|
+
self._endianness = None
|
|
198
|
+
else:
|
|
199
|
+
self._endianness = ffi.string(endian_ptr).decode('utf-8')
|
|
200
|
+
lib.ztensor_free_string(endian_ptr)
|
|
201
|
+
return self._endianness
|
|
202
|
+
|
|
203
|
+
@property
|
|
204
|
+
def checksum(self):
|
|
205
|
+
"""The checksum string if present, else None."""
|
|
206
|
+
if self._checksum == "not_checked":
|
|
207
|
+
checksum_ptr = lib.ztensor_metadata_get_checksum_str(self._ptr)
|
|
208
|
+
if checksum_ptr == ffi.NULL:
|
|
209
|
+
self._checksum = None
|
|
210
|
+
else:
|
|
211
|
+
self._checksum = ffi.string(checksum_ptr).decode('utf-8')
|
|
212
|
+
lib.ztensor_free_string(checksum_ptr)
|
|
213
|
+
return self._checksum
|
|
214
|
+
|
|
144
215
|
|
|
145
216
|
class Reader:
|
|
146
217
|
"""A Pythonic context manager for reading zTensor files."""
|
|
@@ -157,6 +228,39 @@ class Reader:
|
|
|
157
228
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
158
229
|
self._ptr = None
|
|
159
230
|
|
|
231
|
+
def __len__(self):
|
|
232
|
+
"""Returns the number of tensors in the file."""
|
|
233
|
+
if self._ptr is None: raise ZTensorError("Reader is closed.")
|
|
234
|
+
return lib.ztensor_reader_get_metadata_count(self._ptr)
|
|
235
|
+
|
|
236
|
+
def __iter__(self):
|
|
237
|
+
"""Iterates over the metadata of all tensors in the file."""
|
|
238
|
+
if self._ptr is None: raise ZTensorError("Reader is closed.")
|
|
239
|
+
for i in range(len(self)):
|
|
240
|
+
yield self[i]
|
|
241
|
+
|
|
242
|
+
def __getitem__(self, index: int) -> TensorMetadata:
|
|
243
|
+
"""Retrieves metadata for a tensor by its index."""
|
|
244
|
+
if self._ptr is None: raise ZTensorError("Reader is closed.")
|
|
245
|
+
if index >= len(self):
|
|
246
|
+
raise IndexError("Tensor index out of range")
|
|
247
|
+
meta_ptr = lib.ztensor_reader_get_metadata_by_index(self._ptr, index)
|
|
248
|
+
_check_ptr(meta_ptr, f"get_metadata_by_index: {index}")
|
|
249
|
+
return TensorMetadata(meta_ptr)
|
|
250
|
+
|
|
251
|
+
def list_tensors(self) -> list[TensorMetadata]:
|
|
252
|
+
"""Returns a list of all TensorMetadata objects in the file."""
|
|
253
|
+
return list(self)
|
|
254
|
+
|
|
255
|
+
def get_tensor_names(self) -> list[str]:
|
|
256
|
+
"""Returns a list of all tensor names in the file."""
|
|
257
|
+
if self._ptr is None: raise ZTensorError("Reader is closed.")
|
|
258
|
+
c_array_ptr = lib.ztensor_reader_get_all_tensor_names(self._ptr)
|
|
259
|
+
_check_ptr(c_array_ptr, "get_all_tensor_names")
|
|
260
|
+
c_array_ptr = ffi.gc(c_array_ptr, lib.ztensor_free_string_array)
|
|
261
|
+
|
|
262
|
+
return [ffi.string(c_array_ptr.strings[i]).decode('utf-8') for i in range(c_array_ptr.len)]
|
|
263
|
+
|
|
160
264
|
def get_metadata(self, name: str) -> TensorMetadata:
|
|
161
265
|
"""Retrieves metadata for a tensor by its name."""
|
|
162
266
|
if self._ptr is None: raise ZTensorError("Reader is closed.")
|
|
@@ -177,6 +281,7 @@ class Reader:
|
|
|
177
281
|
Returns:
|
|
178
282
|
np.ndarray or torch.Tensor: The tensor data.
|
|
179
283
|
"""
|
|
284
|
+
if self._ptr is None: raise ZTensorError("Reader is closed.")
|
|
180
285
|
if to not in ['numpy', 'torch']:
|
|
181
286
|
raise ValueError(f"Unsupported format: '{to}'. Choose 'numpy' or 'torch'.")
|
|
182
287
|
|
|
@@ -238,6 +343,7 @@ class Writer:
|
|
|
238
343
|
if exc_type is None:
|
|
239
344
|
self.finalize()
|
|
240
345
|
else:
|
|
346
|
+
# If an error occurred, just free the handle without finalizing
|
|
241
347
|
lib.ztensor_writer_free(self._ptr)
|
|
242
348
|
self._ptr = None
|
|
243
349
|
|
|
@@ -294,7 +400,7 @@ class Writer:
|
|
|
294
400
|
"""Finalizes the zTensor file, writing the metadata index."""
|
|
295
401
|
if not self._ptr: raise ZTensorError("Writer is already closed or finalized.")
|
|
296
402
|
status = lib.ztensor_writer_finalize(self._ptr)
|
|
297
|
-
self._ptr = None
|
|
403
|
+
self._ptr = None # The writer is consumed in Rust
|
|
298
404
|
self._finalized = True
|
|
299
405
|
_check_status(status, "finalize")
|
|
300
406
|
|
|
@@ -12,7 +12,35 @@ use crate::models::{ChecksumAlgorithm, DType, DataEndianness, Encoding, Layout,
|
|
|
12
12
|
use crate::reader::ZTensorReader;
|
|
13
13
|
use crate::writer::ZTensorWriter;
|
|
14
14
|
|
|
15
|
+
// --- C-Compatible Structs & Handles ---
|
|
16
|
+
|
|
17
|
+
/// Opaque handle to a file-based zTensor reader.
|
|
18
|
+
pub type CZTensorReader = ZTensorReader<BufReader<std::fs::File>>;
|
|
19
|
+
/// Opaque handle to a file-based zTensor writer.
|
|
20
|
+
pub type CZTensorWriter = ZTensorWriter<BufWriter<std::fs::File>>;
|
|
21
|
+
/// Opaque handle to an in-memory zTensor writer.
|
|
22
|
+
pub type CInMemoryZTensorWriter = ZTensorWriter<Cursor<Vec<u8>>>;
|
|
23
|
+
/// Opaque, owned handle to a tensor's metadata.
|
|
24
|
+
pub type CTensorMetadata = TensorMetadata;
|
|
25
|
+
|
|
26
|
+
/// A self-contained, C-compatible view of owned tensor data.
|
|
27
|
+
#[repr(C)]
|
|
28
|
+
pub struct CTensorDataView {
|
|
29
|
+
pub data: *const c_uchar,
|
|
30
|
+
pub len: size_t,
|
|
31
|
+
// Private field holding the owned Vec<u8> data.
|
|
32
|
+
_owner: *mut c_void,
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
/// A C-compatible, heap-allocated array of C strings.
|
|
36
|
+
#[repr(C)]
|
|
37
|
+
pub struct CStringArray {
|
|
38
|
+
pub strings: *mut *mut c_char,
|
|
39
|
+
pub len: size_t,
|
|
40
|
+
}
|
|
41
|
+
|
|
15
42
|
// --- Error Handling ---
|
|
43
|
+
|
|
16
44
|
lazy_static! {
|
|
17
45
|
static ref LAST_ERROR: Mutex<Option<CString>> = Mutex::new(None);
|
|
18
46
|
}
|
|
@@ -23,6 +51,10 @@ fn update_last_error(err: ZTensorError) {
|
|
|
23
51
|
*LAST_ERROR.lock().unwrap() = Some(msg);
|
|
24
52
|
}
|
|
25
53
|
|
|
54
|
+
/// Retrieves the last error message set by a failed API call.
|
|
55
|
+
///
|
|
56
|
+
/// The returned string is valid until the next API call.
|
|
57
|
+
/// Returns `null` if no error has occurred.
|
|
26
58
|
#[unsafe(no_mangle)]
|
|
27
59
|
pub extern "C" fn ztensor_last_error_message() -> *const c_char {
|
|
28
60
|
match LAST_ERROR.lock().unwrap().as_ref() {
|
|
@@ -31,30 +63,51 @@ pub extern "C" fn ztensor_last_error_message() -> *const c_char {
|
|
|
31
63
|
}
|
|
32
64
|
}
|
|
33
65
|
|
|
34
|
-
// ---
|
|
35
|
-
|
|
36
|
-
// Reader is now generic over R: Read + Seek.
|
|
37
|
-
// For the C API, we'll expose a version that works with files.
|
|
38
|
-
pub type CZTensorReader = ZTensorReader<BufReader<std::fs::File>>;
|
|
39
|
-
|
|
40
|
-
// Writer is also generic. We'll expose a file-based and in-memory version.
|
|
41
|
-
pub type CZTensorWriter = ZTensorWriter<BufWriter<std::fs::File>>;
|
|
42
|
-
pub type CInMemoryZTensorWriter = ZTensorWriter<Cursor<Vec<u8>>>;
|
|
66
|
+
// --- Internal Helpers ---
|
|
43
67
|
|
|
44
|
-
|
|
45
|
-
|
|
68
|
+
/// A macro to safely access the Rust object behind an opaque C pointer.
|
|
69
|
+
/// It checks for null and returns a safe reference, handling the error case.
|
|
70
|
+
macro_rules! ztensor_handle {
|
|
71
|
+
($ptr:expr) => {
|
|
72
|
+
if $ptr.is_null() {
|
|
73
|
+
update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
|
|
74
|
+
return ptr::null_mut();
|
|
75
|
+
} else {
|
|
76
|
+
unsafe { &*$ptr }
|
|
77
|
+
}
|
|
78
|
+
};
|
|
79
|
+
(mut $ptr:expr) => {
|
|
80
|
+
if $ptr.is_null() {
|
|
81
|
+
update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
|
|
82
|
+
return ptr::null_mut();
|
|
83
|
+
} else {
|
|
84
|
+
unsafe { &mut *$ptr }
|
|
85
|
+
}
|
|
86
|
+
};
|
|
87
|
+
($ptr:expr, $err_ret:expr) => {
|
|
88
|
+
if $ptr.is_null() {
|
|
89
|
+
update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
|
|
90
|
+
return $err_ret;
|
|
91
|
+
} else {
|
|
92
|
+
unsafe { &*$ptr }
|
|
93
|
+
}
|
|
94
|
+
};
|
|
95
|
+
(mut $ptr:expr, $err_ret:expr) => {
|
|
96
|
+
if $ptr.is_null() {
|
|
97
|
+
update_last_error(ZTensorError::Other("Null pointer passed as handle".into()));
|
|
98
|
+
return $err_ret;
|
|
99
|
+
} else {
|
|
100
|
+
unsafe { &mut *$ptr }
|
|
101
|
+
}
|
|
102
|
+
};
|
|
103
|
+
}
|
|
46
104
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
pub data: *const c_uchar,
|
|
51
|
-
pub len: size_t,
|
|
52
|
-
// Private field to hold the owned data, making this struct self-contained.
|
|
53
|
-
// The C side only sees a pointer to the data, but this struct owns the Vec.
|
|
54
|
-
_owner: *mut c_void,
|
|
105
|
+
/// Helper to convert a Rust String into a C-style, null-terminated string pointer.
|
|
106
|
+
fn to_cstring(s: String) -> *mut c_char {
|
|
107
|
+
CString::new(s).map_or(ptr::null_mut(), |cs| cs.into_raw())
|
|
55
108
|
}
|
|
56
109
|
|
|
57
|
-
// --- Reader
|
|
110
|
+
// --- Reader API ---
|
|
58
111
|
|
|
59
112
|
#[unsafe(no_mangle)]
|
|
60
113
|
pub extern "C" fn ztensor_reader_open(path_str: *const c_char) -> *mut CZTensorReader {
|
|
@@ -62,11 +115,8 @@ pub extern "C" fn ztensor_reader_open(path_str: *const c_char) -> *mut CZTensorR
|
|
|
62
115
|
update_last_error(ZTensorError::Other("Null path provided".into()));
|
|
63
116
|
return ptr::null_mut();
|
|
64
117
|
}
|
|
65
|
-
let path = unsafe { CStr::from_ptr(path_str) }
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
let path = match path_res {
|
|
69
|
-
Ok(p) => p,
|
|
118
|
+
let path = match unsafe { CStr::from_ptr(path_str).to_str() } {
|
|
119
|
+
Ok(s) => Path::new(s),
|
|
70
120
|
Err(_) => {
|
|
71
121
|
update_last_error(ZTensorError::Other("Invalid UTF-8 path".into()));
|
|
72
122
|
return ptr::null_mut();
|
|
@@ -82,22 +132,9 @@ pub extern "C" fn ztensor_reader_open(path_str: *const c_char) -> *mut CZTensorR
|
|
|
82
132
|
}
|
|
83
133
|
}
|
|
84
134
|
|
|
85
|
-
#[unsafe(no_mangle)]
|
|
86
|
-
pub extern "C" fn ztensor_reader_free(reader_ptr: *mut CZTensorReader) {
|
|
87
|
-
if !reader_ptr.is_null() {
|
|
88
|
-
unsafe {
|
|
89
|
-
let _ = Box::from_raw(reader_ptr);
|
|
90
|
-
};
|
|
91
|
-
}
|
|
92
|
-
}
|
|
93
|
-
|
|
94
135
|
#[unsafe(no_mangle)]
|
|
95
136
|
pub extern "C" fn ztensor_reader_get_metadata_count(reader_ptr: *const CZTensorReader) -> size_t {
|
|
96
|
-
let reader =
|
|
97
|
-
reader_ptr
|
|
98
|
-
.as_ref()
|
|
99
|
-
.expect("Null pointer passed to ztensor_reader_get_metadata_count")
|
|
100
|
-
};
|
|
137
|
+
let reader = ztensor_handle!(reader_ptr, 0);
|
|
101
138
|
reader.list_tensors().len()
|
|
102
139
|
}
|
|
103
140
|
|
|
@@ -106,11 +143,7 @@ pub extern "C" fn ztensor_reader_get_metadata_by_name(
|
|
|
106
143
|
reader_ptr: *const CZTensorReader,
|
|
107
144
|
name_str: *const c_char,
|
|
108
145
|
) -> *mut CTensorMetadata {
|
|
109
|
-
let reader =
|
|
110
|
-
reader_ptr
|
|
111
|
-
.as_ref()
|
|
112
|
-
.expect("Null pointer passed to ztensor_reader_get_metadata_by_name")
|
|
113
|
-
};
|
|
146
|
+
let reader = ztensor_handle!(reader_ptr);
|
|
114
147
|
if name_str.is_null() {
|
|
115
148
|
update_last_error(ZTensorError::Other("Null name pointer provided".into()));
|
|
116
149
|
return ptr::null_mut();
|
|
@@ -126,24 +159,53 @@ pub extern "C" fn ztensor_reader_get_metadata_by_name(
|
|
|
126
159
|
}
|
|
127
160
|
}
|
|
128
161
|
|
|
129
|
-
|
|
162
|
+
#[unsafe(no_mangle)]
|
|
163
|
+
pub extern "C" fn ztensor_reader_get_metadata_by_index(
|
|
164
|
+
reader_ptr: *const CZTensorReader,
|
|
165
|
+
index: size_t,
|
|
166
|
+
) -> *mut CTensorMetadata {
|
|
167
|
+
let reader = ztensor_handle!(reader_ptr);
|
|
168
|
+
match reader.list_tensors().get(index) {
|
|
169
|
+
Some(metadata) => Box::into_raw(Box::new(metadata.clone())),
|
|
170
|
+
None => {
|
|
171
|
+
update_last_error(ZTensorError::Other(format!(
|
|
172
|
+
"Index {} is out of bounds for tensor list of length {}",
|
|
173
|
+
index,
|
|
174
|
+
reader.list_tensors().len()
|
|
175
|
+
)));
|
|
176
|
+
ptr::null_mut()
|
|
177
|
+
}
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
#[unsafe(no_mangle)]
|
|
182
|
+
pub extern "C" fn ztensor_reader_get_all_tensor_names(
|
|
183
|
+
reader_ptr: *const CZTensorReader,
|
|
184
|
+
) -> *mut CStringArray {
|
|
185
|
+
let reader = ztensor_handle!(reader_ptr);
|
|
186
|
+
let names: Vec<CString> = reader
|
|
187
|
+
.list_tensors()
|
|
188
|
+
.iter()
|
|
189
|
+
.map(|m| CString::new(m.name.as_str()).unwrap())
|
|
190
|
+
.collect();
|
|
191
|
+
|
|
192
|
+
let mut c_names: Vec<*mut c_char> = names.into_iter().map(|s| s.into_raw()).collect();
|
|
193
|
+
let string_array = Box::new(CStringArray {
|
|
194
|
+
strings: c_names.as_mut_ptr(),
|
|
195
|
+
len: c_names.len(),
|
|
196
|
+
});
|
|
197
|
+
|
|
198
|
+
std::mem::forget(c_names); // C side is now responsible for this memory
|
|
199
|
+
Box::into_raw(string_array)
|
|
200
|
+
}
|
|
130
201
|
|
|
131
|
-
/// Reads tensor data into a view without copying. The view must be freed.
|
|
132
202
|
#[unsafe(no_mangle)]
|
|
133
203
|
pub extern "C" fn ztensor_reader_read_tensor_view(
|
|
134
204
|
reader_ptr: *mut CZTensorReader,
|
|
135
205
|
metadata_ptr: *const CTensorMetadata,
|
|
136
206
|
) -> *mut CTensorDataView {
|
|
137
|
-
let reader =
|
|
138
|
-
|
|
139
|
-
.as_mut()
|
|
140
|
-
.expect("Null reader pointer to read_tensor_view")
|
|
141
|
-
};
|
|
142
|
-
let metadata = unsafe {
|
|
143
|
-
metadata_ptr
|
|
144
|
-
.as_ref()
|
|
145
|
-
.expect("Null metadata pointer to read_tensor_view")
|
|
146
|
-
};
|
|
207
|
+
let reader = ztensor_handle!(mut reader_ptr);
|
|
208
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
147
209
|
|
|
148
210
|
match reader.read_raw_tensor_data(metadata) {
|
|
149
211
|
Ok(data_vec) => {
|
|
@@ -161,84 +223,7 @@ pub extern "C" fn ztensor_reader_read_tensor_view(
|
|
|
161
223
|
}
|
|
162
224
|
}
|
|
163
225
|
|
|
164
|
-
|
|
165
|
-
#[unsafe(no_mangle)]
|
|
166
|
-
pub extern "C" fn ztensor_free_tensor_view(view_ptr: *mut CTensorDataView) {
|
|
167
|
-
if !view_ptr.is_null() {
|
|
168
|
-
unsafe {
|
|
169
|
-
let view = Box::from_raw(view_ptr);
|
|
170
|
-
// This will drop the Vec<u8> that `_owner` points to
|
|
171
|
-
let _ = Box::from_raw(view._owner as *mut Vec<u8>);
|
|
172
|
-
}
|
|
173
|
-
}
|
|
174
|
-
}
|
|
175
|
-
|
|
176
|
-
// --- Metadata Accessors ---
|
|
177
|
-
// (These are largely the same but ensured to be safe)
|
|
178
|
-
|
|
179
|
-
#[unsafe(no_mangle)]
|
|
180
|
-
pub extern "C" fn ztensor_metadata_free(metadata_ptr: *mut CTensorMetadata) {
|
|
181
|
-
if !metadata_ptr.is_null() {
|
|
182
|
-
unsafe {
|
|
183
|
-
let _ = Box::from_raw(metadata_ptr);
|
|
184
|
-
};
|
|
185
|
-
}
|
|
186
|
-
}
|
|
187
|
-
|
|
188
|
-
fn to_cstring(s: String) -> *mut c_char {
|
|
189
|
-
CString::new(s).map_or(ptr::null_mut(), |cs| cs.into_raw())
|
|
190
|
-
}
|
|
191
|
-
|
|
192
|
-
#[unsafe(no_mangle)]
|
|
193
|
-
pub extern "C" fn ztensor_metadata_get_name(metadata_ptr: *const CTensorMetadata) -> *mut c_char {
|
|
194
|
-
let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
|
|
195
|
-
to_cstring(metadata.name.clone())
|
|
196
|
-
}
|
|
197
|
-
|
|
198
|
-
#[unsafe(no_mangle)]
|
|
199
|
-
pub extern "C" fn ztensor_metadata_get_dtype_str(
|
|
200
|
-
metadata_ptr: *const CTensorMetadata,
|
|
201
|
-
) -> *mut c_char {
|
|
202
|
-
let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
|
|
203
|
-
to_cstring(metadata.dtype.to_string_key())
|
|
204
|
-
}
|
|
205
|
-
|
|
206
|
-
// ... other metadata accessors like get_offset, get_size, etc.
|
|
207
|
-
|
|
208
|
-
/// Returns the number of dimensions in the tensor's shape.
|
|
209
|
-
#[unsafe(no_mangle)]
|
|
210
|
-
pub extern "C" fn ztensor_metadata_get_shape_len(metadata_ptr: *const CTensorMetadata) -> size_t {
|
|
211
|
-
let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
|
|
212
|
-
metadata.shape.len()
|
|
213
|
-
}
|
|
214
|
-
|
|
215
|
-
/// Returns a pointer to the shape data (an array of u64).
|
|
216
|
-
/// The caller owns this memory and must free it with `ztensor_free_u64_array`.
|
|
217
|
-
#[unsafe(no_mangle)]
|
|
218
|
-
pub extern "C" fn ztensor_metadata_get_shape_data(
|
|
219
|
-
metadata_ptr: *const CTensorMetadata,
|
|
220
|
-
) -> *mut u64 {
|
|
221
|
-
let metadata = unsafe { metadata_ptr.as_ref().expect("Null metadata pointer") };
|
|
222
|
-
// Clone the shape into a new Vec, get its raw pointer, and forget it
|
|
223
|
-
// so Rust doesn't deallocate it. The C side is now responsible.
|
|
224
|
-
let mut shape_vec = metadata.shape.clone();
|
|
225
|
-
let ptr = shape_vec.as_mut_ptr();
|
|
226
|
-
std::mem::forget(shape_vec);
|
|
227
|
-
ptr
|
|
228
|
-
}
|
|
229
|
-
|
|
230
|
-
/// Frees the shape array allocated by `ztensor_metadata_get_shape_data`.
|
|
231
|
-
#[unsafe(no_mangle)]
|
|
232
|
-
pub extern "C" fn ztensor_free_u64_array(ptr: *mut u64, len: size_t) {
|
|
233
|
-
if !ptr.is_null() {
|
|
234
|
-
unsafe {
|
|
235
|
-
// Reconstitute the Vec from the raw parts and let it drop, freeing the memory.
|
|
236
|
-
let _ = Vec::from_raw_parts(ptr, len, len);
|
|
237
|
-
}
|
|
238
|
-
}
|
|
239
|
-
}
|
|
240
|
-
|
|
241
|
-
// --- Writer Functions ---
|
|
226
|
+
// --- Writer API ---
|
|
242
227
|
|
|
243
228
|
#[unsafe(no_mangle)]
|
|
244
229
|
pub extern "C" fn ztensor_writer_create(path_str: *const c_char) -> *mut CZTensorWriter {
|
|
@@ -246,9 +231,8 @@ pub extern "C" fn ztensor_writer_create(path_str: *const c_char) -> *mut CZTenso
|
|
|
246
231
|
update_last_error(ZTensorError::Other("Null path provided".into()));
|
|
247
232
|
return ptr::null_mut();
|
|
248
233
|
}
|
|
249
|
-
let path = unsafe { CStr::from_ptr(path_str) }
|
|
250
|
-
|
|
251
|
-
Ok(p) => p,
|
|
234
|
+
let path = match unsafe { CStr::from_ptr(path_str).to_str() } {
|
|
235
|
+
Ok(s) => Path::new(s),
|
|
252
236
|
Err(_) => {
|
|
253
237
|
update_last_error(ZTensorError::Other("Invalid UTF-8 path".into()));
|
|
254
238
|
return ptr::null_mut();
|
|
@@ -264,14 +248,6 @@ pub extern "C" fn ztensor_writer_create(path_str: *const c_char) -> *mut CZTenso
|
|
|
264
248
|
}
|
|
265
249
|
}
|
|
266
250
|
|
|
267
|
-
#[unsafe(no_mangle)]
|
|
268
|
-
pub extern "C" fn ztensor_writer_free(writer_ptr: *mut CZTensorWriter) {
|
|
269
|
-
if !writer_ptr.is_null() {
|
|
270
|
-
let _ = unsafe { Box::from_raw(writer_ptr) };
|
|
271
|
-
// Dropping the writer will close the file.
|
|
272
|
-
}
|
|
273
|
-
}
|
|
274
|
-
|
|
275
251
|
#[unsafe(no_mangle)]
|
|
276
252
|
pub extern "C" fn ztensor_writer_add_tensor(
|
|
277
253
|
writer_ptr: *mut CZTensorWriter,
|
|
@@ -282,16 +258,26 @@ pub extern "C" fn ztensor_writer_add_tensor(
|
|
|
282
258
|
data_ptr: *const c_uchar,
|
|
283
259
|
data_len: size_t,
|
|
284
260
|
) -> c_int {
|
|
285
|
-
let writer =
|
|
261
|
+
let writer = ztensor_handle!(mut writer_ptr, -1);
|
|
286
262
|
let name = unsafe { CStr::from_ptr(name_str).to_str().unwrap() };
|
|
287
263
|
let shape = unsafe { slice::from_raw_parts(shape_ptr, shape_len) };
|
|
288
264
|
let dtype_str = unsafe { CStr::from_ptr(dtype_str).to_str().unwrap() };
|
|
289
265
|
let data = unsafe { slice::from_raw_parts(data_ptr, data_len) };
|
|
290
266
|
|
|
291
267
|
let dtype = match dtype_str {
|
|
268
|
+
"float64" => DType::Float64,
|
|
292
269
|
"float32" => DType::Float32,
|
|
270
|
+
"float16" => DType::Float16,
|
|
271
|
+
"bfloat16" => DType::BFloat16,
|
|
272
|
+
"int64" => DType::Int64,
|
|
273
|
+
"int32" => DType::Int32,
|
|
274
|
+
"int16" => DType::Int16,
|
|
275
|
+
"int8" => DType::Int8,
|
|
276
|
+
"uint64" => DType::Uint64,
|
|
277
|
+
"uint32" => DType::Uint32,
|
|
278
|
+
"uint16" => DType::Uint16,
|
|
293
279
|
"uint8" => DType::Uint8,
|
|
294
|
-
|
|
280
|
+
"bool" => DType::Bool,
|
|
295
281
|
_ => {
|
|
296
282
|
update_last_error(ZTensorError::UnsupportedDType(dtype_str.to_string()));
|
|
297
283
|
return -1;
|
|
@@ -305,7 +291,7 @@ pub extern "C" fn ztensor_writer_add_tensor(
|
|
|
305
291
|
Layout::Dense,
|
|
306
292
|
Encoding::Raw,
|
|
307
293
|
data.to_vec(),
|
|
308
|
-
Some(DataEndianness::Little),
|
|
294
|
+
Some(DataEndianness::Little),
|
|
309
295
|
ChecksumAlgorithm::None,
|
|
310
296
|
None,
|
|
311
297
|
);
|
|
@@ -324,8 +310,6 @@ pub extern "C" fn ztensor_writer_finalize(writer_ptr: *mut CZTensorWriter) -> c_
|
|
|
324
310
|
if writer_ptr.is_null() {
|
|
325
311
|
return -1;
|
|
326
312
|
}
|
|
327
|
-
// `from_raw` takes ownership and will drop the writer when it goes out of scope.
|
|
328
|
-
// The writer's `drop` implementation should handle finalization.
|
|
329
313
|
let writer = unsafe { Box::from_raw(writer_ptr) };
|
|
330
314
|
match writer.finalize() {
|
|
331
315
|
Ok(_) => 0,
|
|
@@ -336,13 +320,147 @@ pub extern "C" fn ztensor_writer_finalize(writer_ptr: *mut CZTensorWriter) -> c_
|
|
|
336
320
|
}
|
|
337
321
|
}
|
|
338
322
|
|
|
339
|
-
// ---
|
|
323
|
+
// --- Metadata API ---
|
|
324
|
+
|
|
325
|
+
#[unsafe(no_mangle)]
|
|
326
|
+
pub extern "C" fn ztensor_metadata_get_name(metadata_ptr: *const CTensorMetadata) -> *mut c_char {
|
|
327
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
328
|
+
to_cstring(metadata.name.clone())
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
#[unsafe(no_mangle)]
|
|
332
|
+
pub extern "C" fn ztensor_metadata_get_dtype_str(
|
|
333
|
+
metadata_ptr: *const CTensorMetadata,
|
|
334
|
+
) -> *mut c_char {
|
|
335
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
336
|
+
to_cstring(metadata.dtype.to_string_key())
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
#[unsafe(no_mangle)]
|
|
340
|
+
pub extern "C" fn ztensor_metadata_get_offset(metadata_ptr: *const CTensorMetadata) -> u64 {
|
|
341
|
+
let metadata = ztensor_handle!(metadata_ptr, 0);
|
|
342
|
+
metadata.offset
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
#[unsafe(no_mangle)]
|
|
346
|
+
pub extern "C" fn ztensor_metadata_get_size(metadata_ptr: *const CTensorMetadata) -> u64 {
|
|
347
|
+
let metadata = ztensor_handle!(metadata_ptr, 0);
|
|
348
|
+
metadata.size
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
#[unsafe(no_mangle)]
|
|
352
|
+
pub extern "C" fn ztensor_metadata_get_layout_str(
|
|
353
|
+
metadata_ptr: *const CTensorMetadata,
|
|
354
|
+
) -> *mut c_char {
|
|
355
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
356
|
+
to_cstring(format!("{:?}", metadata.layout).to_lowercase())
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
#[unsafe(no_mangle)]
|
|
360
|
+
pub extern "C" fn ztensor_metadata_get_encoding_str(
|
|
361
|
+
metadata_ptr: *const CTensorMetadata,
|
|
362
|
+
) -> *mut c_char {
|
|
363
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
364
|
+
to_cstring(format!("{:?}", metadata.encoding).to_lowercase())
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
#[unsafe(no_mangle)]
|
|
368
|
+
pub extern "C" fn ztensor_metadata_get_data_endianness_str(
|
|
369
|
+
metadata_ptr: *const CTensorMetadata,
|
|
370
|
+
) -> *mut c_char {
|
|
371
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
372
|
+
match &metadata.data_endianness {
|
|
373
|
+
Some(endianness) => to_cstring(format!("{:?}", endianness).to_lowercase()),
|
|
374
|
+
None => ptr::null_mut(),
|
|
375
|
+
}
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
#[unsafe(no_mangle)]
|
|
379
|
+
pub extern "C" fn ztensor_metadata_get_checksum_str(
|
|
380
|
+
metadata_ptr: *const CTensorMetadata,
|
|
381
|
+
) -> *mut c_char {
|
|
382
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
383
|
+
match &metadata.checksum {
|
|
384
|
+
Some(s) => to_cstring(s.clone()),
|
|
385
|
+
None => ptr::null_mut(),
|
|
386
|
+
}
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
#[unsafe(no_mangle)]
|
|
390
|
+
pub extern "C" fn ztensor_metadata_get_shape_len(metadata_ptr: *const CTensorMetadata) -> size_t {
|
|
391
|
+
let metadata = ztensor_handle!(metadata_ptr, 0);
|
|
392
|
+
metadata.shape.len()
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
#[unsafe(no_mangle)]
|
|
396
|
+
pub extern "C" fn ztensor_metadata_get_shape_data(
|
|
397
|
+
metadata_ptr: *const CTensorMetadata,
|
|
398
|
+
) -> *mut u64 {
|
|
399
|
+
let metadata = ztensor_handle!(metadata_ptr);
|
|
400
|
+
let mut shape_vec = metadata.shape.clone();
|
|
401
|
+
let ptr = shape_vec.as_mut_ptr();
|
|
402
|
+
std::mem::forget(shape_vec);
|
|
403
|
+
ptr
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
// --- Memory Management API ---
|
|
407
|
+
|
|
408
|
+
#[unsafe(no_mangle)]
|
|
409
|
+
pub extern "C" fn ztensor_reader_free(reader_ptr: *mut CZTensorReader) {
|
|
410
|
+
if !reader_ptr.is_null() {
|
|
411
|
+
let _ = unsafe { Box::from_raw(reader_ptr) };
|
|
412
|
+
}
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
#[unsafe(no_mangle)]
|
|
416
|
+
pub extern "C" fn ztensor_writer_free(writer_ptr: *mut CZTensorWriter) {
|
|
417
|
+
if !writer_ptr.is_null() {
|
|
418
|
+
let _ = unsafe { Box::from_raw(writer_ptr) };
|
|
419
|
+
}
|
|
420
|
+
}
|
|
421
|
+
|
|
422
|
+
#[unsafe(no_mangle)]
|
|
423
|
+
pub extern "C" fn ztensor_metadata_free(metadata_ptr: *mut CTensorMetadata) {
|
|
424
|
+
if !metadata_ptr.is_null() {
|
|
425
|
+
let _ = unsafe { Box::from_raw(metadata_ptr) };
|
|
426
|
+
}
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
#[unsafe(no_mangle)]
|
|
430
|
+
pub extern "C" fn ztensor_free_tensor_view(view_ptr: *mut CTensorDataView) {
|
|
431
|
+
if !view_ptr.is_null() {
|
|
432
|
+
unsafe {
|
|
433
|
+
let view = Box::from_raw(view_ptr);
|
|
434
|
+
// This reconstitutes the `Vec<u8>` and allows it to be dropped.
|
|
435
|
+
let _ = Box::from_raw(view._owner as *mut Vec<u8>);
|
|
436
|
+
}
|
|
437
|
+
}
|
|
438
|
+
}
|
|
340
439
|
|
|
341
440
|
#[unsafe(no_mangle)]
|
|
342
441
|
pub extern "C" fn ztensor_free_string(s: *mut c_char) {
|
|
343
442
|
if !s.is_null() {
|
|
344
|
-
unsafe {
|
|
345
|
-
|
|
346
|
-
|
|
443
|
+
let _ = unsafe { CString::from_raw(s) };
|
|
444
|
+
}
|
|
445
|
+
}
|
|
446
|
+
|
|
447
|
+
#[unsafe(no_mangle)]
|
|
448
|
+
pub extern "C" fn ztensor_free_string_array(arr_ptr: *mut CStringArray) {
|
|
449
|
+
if arr_ptr.is_null() {
|
|
450
|
+
return;
|
|
451
|
+
}
|
|
452
|
+
unsafe {
|
|
453
|
+
let arr = Box::from_raw(arr_ptr);
|
|
454
|
+
let strings = Vec::from_raw_parts(arr.strings, arr.len, arr.len);
|
|
455
|
+
for s_ptr in strings {
|
|
456
|
+
let _ = CString::from_raw(s_ptr);
|
|
457
|
+
}
|
|
458
|
+
}
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
#[unsafe(no_mangle)]
|
|
462
|
+
pub extern "C" fn ztensor_free_u64_array(ptr: *mut u64, len: size_t) {
|
|
463
|
+
if !ptr.is_null() {
|
|
464
|
+
let _ = unsafe { Vec::from_raw_parts(ptr, len, len) };
|
|
347
465
|
}
|
|
348
466
|
}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|