cellarr-array 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cellarr-array might be problematic. Click here for more details.
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/.gitignore +2 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/CHANGELOG.md +6 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/PKG-INFO +4 -1
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/setup.cfg +3 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/src/cellarr_array/__init__.py +2 -4
- cellarr_array-0.2.0/src/cellarr_array/core/__init__.py +3 -0
- cellarr_array-0.1.0/src/cellarr_array/cellarray_base.py → cellarr_array-0.2.0/src/cellarr_array/core/base.py +1 -1
- cellarr_array-0.1.0/src/cellarr_array/cellarray_dense.py → cellarr_array-0.2.0/src/cellarr_array/core/dense.py +2 -3
- {cellarr_array-0.1.0/src/cellarr_array → cellarr_array-0.2.0/src/cellarr_array/core}/helpers.py +77 -43
- cellarr_array-0.1.0/src/cellarr_array/cellarray_sparse.py → cellarr_array-0.2.0/src/cellarr_array/core/sparse.py +11 -16
- cellarr_array-0.2.0/src/cellarr_array/dataloaders/__init__.py +3 -0
- cellarr_array-0.2.0/src/cellarr_array/dataloaders/denseloader.py +198 -0
- cellarr_array-0.2.0/src/cellarr_array/dataloaders/iterabledataloader.py +320 -0
- cellarr_array-0.2.0/src/cellarr_array/dataloaders/sparseloader.py +230 -0
- cellarr_array-0.2.0/src/cellarr_array/dataloaders/utils.py +26 -0
- cellarr_array-0.2.0/src/cellarr_array/utils/__init__.py +3 -0
- cellarr_array-0.2.0/src/cellarr_array/utils/mock.py +167 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/src/cellarr_array.egg-info/PKG-INFO +4 -1
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/src/cellarr_array.egg-info/SOURCES.txt +15 -5
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/src/cellarr_array.egg-info/requires.txt +4 -0
- cellarr_array-0.2.0/tests/conftest.py +233 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/tests/test_all.py +1 -1
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/tests/test_dense.py +4 -4
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/tests/test_helpers.py +17 -5
- cellarr_array-0.2.0/tests/test_iterable_loader.py +288 -0
- cellarr_array-0.2.0/tests/test_map_loader.py +289 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/tests/test_sparse.py +1 -1
- cellarr_array-0.1.0/tests/conftest.py +0 -91
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/.coveragerc +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/.github/workflows/publish-pypi.yml +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/.github/workflows/run-tests.yml +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/.pre-commit-config.yaml +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/.readthedocs.yml +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/AUTHORS.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/CONTRIBUTING.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/LICENSE.txt +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/README.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/Makefile +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/_static/.gitignore +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/authors.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/changelog.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/conf.py +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/contributing.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/index.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/license.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/readme.md +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/docs/requirements.txt +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/pyproject.toml +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/setup.py +0 -0
- {cellarr_array-0.1.0/src/cellarr_array → cellarr_array-0.2.0/src/cellarr_array/utils}/config.py +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/src/cellarr_array.egg-info/dependency_links.txt +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/src/cellarr_array.egg-info/not-zip-safe +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/src/cellarr_array.egg-info/top_level.txt +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/tests/test_inmemory.py +0 -0
- {cellarr_array-0.1.0 → cellarr_array-0.2.0}/tox.ini +0 -0
|
@@ -1,5 +1,11 @@
|
|
|
1
1
|
# Changelog
|
|
2
2
|
|
|
3
|
+
## Version 0.2.0
|
|
4
|
+
|
|
5
|
+
- Dataloaders for sparse and dense arrays. We provide templates for both map-style and iterable-style dataloaders. Users are expected to understand the caveats of both of these approaches.
|
|
6
|
+
- Fixed a bug with slicing on 1D arrays and made many improvements for optimizing slicing parameters.
|
|
7
|
+
- Update documentation and tests.
|
|
8
|
+
|
|
3
9
|
## Version 0.1.0
|
|
4
10
|
|
|
5
11
|
- Support cellarr-arrays on user provided tiledb array objects.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cellarr-array
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Base class for handling TileDB backed arrays.
|
|
5
5
|
Home-page: https://github.com/cellarr/cellarr-array
|
|
6
6
|
Author: Jayaram Kancherla
|
|
@@ -16,10 +16,13 @@ Requires-Dist: importlib-metadata; python_version < "3.8"
|
|
|
16
16
|
Requires-Dist: tiledb
|
|
17
17
|
Requires-Dist: numpy
|
|
18
18
|
Requires-Dist: scipy
|
|
19
|
+
Provides-Extra: optional
|
|
20
|
+
Requires-Dist: torch; extra == "optional"
|
|
19
21
|
Provides-Extra: testing
|
|
20
22
|
Requires-Dist: setuptools; extra == "testing"
|
|
21
23
|
Requires-Dist: pytest; extra == "testing"
|
|
22
24
|
Requires-Dist: pytest-cov; extra == "testing"
|
|
25
|
+
Requires-Dist: torch; extra == "testing"
|
|
23
26
|
Dynamic: license-file
|
|
24
27
|
|
|
25
28
|
[](https://pypi.org/project/cellarr-array/)
|
|
@@ -15,7 +15,5 @@ except PackageNotFoundError: # pragma: no cover
|
|
|
15
15
|
finally:
|
|
16
16
|
del version, PackageNotFoundError
|
|
17
17
|
|
|
18
|
-
from .
|
|
19
|
-
from .
|
|
20
|
-
from .cellarray_sparse import SparseCellArray
|
|
21
|
-
from .helpers import create_cellarray, SliceHelper
|
|
18
|
+
from .core import DenseCellArray, SparseCellArray
|
|
19
|
+
from .utils import CellArrConfig, ConsolidationConfig, create_cellarray
|
|
@@ -7,7 +7,7 @@ from typing import List, Tuple, Union
|
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
|
-
from .
|
|
10
|
+
from .base import CellArray
|
|
11
11
|
from .helpers import SliceHelper
|
|
12
12
|
|
|
13
13
|
__author__ = "Jayaram Kancherla"
|
|
@@ -92,7 +92,6 @@ class DenseCellArray(CellArray):
|
|
|
92
92
|
if len(data.shape) != self.ndim:
|
|
93
93
|
raise ValueError(f"Data dimensions {data.shape} don't match array dimensions {self.shape}.")
|
|
94
94
|
|
|
95
|
-
# Check bounds
|
|
96
95
|
end_row = start_row + data.shape[0]
|
|
97
96
|
if end_row > self.shape[0]:
|
|
98
97
|
raise ValueError(
|
|
@@ -102,7 +101,6 @@ class DenseCellArray(CellArray):
|
|
|
102
101
|
if self.ndim == 2 and data.shape[1] != self.shape[1]:
|
|
103
102
|
raise ValueError(f"Data columns {data.shape[1]} don't match array columns {self.shape[1]}.")
|
|
104
103
|
|
|
105
|
-
# Construct write region
|
|
106
104
|
if self.ndim == 1:
|
|
107
105
|
write_region = slice(start_row, end_row)
|
|
108
106
|
else: # 2D
|
|
@@ -110,4 +108,5 @@ class DenseCellArray(CellArray):
|
|
|
110
108
|
|
|
111
109
|
# write_data = {self._attr: data} if len(self.attr_names) > 1 else data
|
|
112
110
|
with self.open_array(mode="w") as array:
|
|
111
|
+
print("write_region", write_region)
|
|
113
112
|
array[write_region] = data
|
{cellarr_array-0.1.0/src/cellarr_array → cellarr_array-0.2.0/src/cellarr_array/core}/helpers.py
RENAMED
|
@@ -8,7 +8,7 @@ from typing import List, Optional, Tuple, Union
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tiledb
|
|
10
10
|
|
|
11
|
-
from .config import CellArrConfig
|
|
11
|
+
from ..utils.config import CellArrConfig
|
|
12
12
|
|
|
13
13
|
__author__ = "Jayaram Kancherla"
|
|
14
14
|
__copyright__ = "Jayaram Kancherla"
|
|
@@ -52,7 +52,7 @@ def create_cellarray(
|
|
|
52
52
|
Optional list of dimension names.
|
|
53
53
|
|
|
54
54
|
dim_dtypes:
|
|
55
|
-
Optional list of dimension dtypes.
|
|
55
|
+
Optional list of dimension dtypes. Defaults to numpy's uint32.
|
|
56
56
|
|
|
57
57
|
attr_name:
|
|
58
58
|
Name of the data attribute.
|
|
@@ -67,29 +67,28 @@ def create_cellarray(
|
|
|
67
67
|
ValueError: If dimensions are invalid or inputs are inconsistent.
|
|
68
68
|
"""
|
|
69
69
|
config = config or CellArrConfig()
|
|
70
|
+
tiledb_ctx = tiledb.Config(config.ctx_config) if config.ctx_config else None
|
|
70
71
|
|
|
71
72
|
if attr_dtype is None:
|
|
72
73
|
attr_dtype = np.float32
|
|
73
74
|
if isinstance(attr_dtype, str):
|
|
74
75
|
attr_dtype = np.dtype(attr_dtype)
|
|
75
76
|
|
|
76
|
-
# Require either shape or dim_dtypes
|
|
77
77
|
if shape is None and dim_dtypes is None:
|
|
78
78
|
raise ValueError("Either 'shape' or 'dim_dtypes' must be provided.")
|
|
79
79
|
|
|
80
80
|
if shape is not None:
|
|
81
81
|
if len(shape) not in (1, 2):
|
|
82
|
-
raise ValueError("
|
|
82
|
+
raise ValueError("Shape must have 1 or 2 dimensions.")
|
|
83
83
|
|
|
84
84
|
# Set dimension dtypes, defaults to numpy uint32
|
|
85
85
|
if dim_dtypes is None:
|
|
86
86
|
dim_dtypes = [np.uint32] * len(shape)
|
|
87
87
|
else:
|
|
88
88
|
if len(dim_dtypes) not in (1, 2):
|
|
89
|
-
raise ValueError("
|
|
89
|
+
raise ValueError("Array must have 1 or 2 dimensions.")
|
|
90
90
|
dim_dtypes = [np.dtype(dt) if isinstance(dt, str) else dt for dt in dim_dtypes]
|
|
91
91
|
|
|
92
|
-
# Calculate shape from dtypes if needed
|
|
93
92
|
if shape is None:
|
|
94
93
|
shape = tuple(np.iinfo(dt).max if np.issubdtype(dt, np.integer) else None for dt in dim_dtypes)
|
|
95
94
|
if None in shape:
|
|
@@ -97,7 +96,6 @@ def create_cellarray(
|
|
|
97
96
|
np.iinfo(dt).max if s is None and np.issubdtype(dt, np.integer) else s for s, dt in zip(shape, dim_dtypes)
|
|
98
97
|
)
|
|
99
98
|
|
|
100
|
-
# Set dimension names
|
|
101
99
|
if dim_names is None:
|
|
102
100
|
dim_names = [f"dim_{i}" for i in range(len(shape))]
|
|
103
101
|
|
|
@@ -107,40 +105,43 @@ def create_cellarray(
|
|
|
107
105
|
|
|
108
106
|
dom = tiledb.Domain(
|
|
109
107
|
*[
|
|
110
|
-
tiledb.Dim(
|
|
108
|
+
tiledb.Dim(
|
|
109
|
+
name=name,
|
|
110
|
+
# supporting empty dimensions
|
|
111
|
+
domain=(0, 0 if s == 0 else s - 1),
|
|
112
|
+
tile=min(1 if s == 0 else s // 2, config.tile_capacity // 2),
|
|
113
|
+
dtype=dt,
|
|
114
|
+
)
|
|
111
115
|
for name, s, dt in zip(dim_names, shape, dim_dtypes)
|
|
112
116
|
],
|
|
113
|
-
ctx=
|
|
117
|
+
ctx=tiledb_ctx,
|
|
114
118
|
)
|
|
115
|
-
|
|
116
|
-
attr = tiledb.Attr(
|
|
119
|
+
attr_obj = tiledb.Attr(
|
|
117
120
|
name=attr_name,
|
|
118
121
|
dtype=attr_dtype,
|
|
119
122
|
filters=config.attrs_filters.get(attr_name, config.attrs_filters.get("", None)),
|
|
123
|
+
ctx=tiledb_ctx,
|
|
120
124
|
)
|
|
121
|
-
|
|
122
125
|
schema = tiledb.ArraySchema(
|
|
123
126
|
domain=dom,
|
|
124
|
-
attrs=[
|
|
127
|
+
attrs=[attr_obj],
|
|
125
128
|
cell_order=config.cell_order,
|
|
126
129
|
tile_order=config.tile_order,
|
|
127
130
|
sparse=sparse,
|
|
128
131
|
coords_filters=config.coords_filters,
|
|
129
132
|
offsets_filters=config.offsets_filters,
|
|
130
|
-
ctx=
|
|
133
|
+
ctx=tiledb_ctx,
|
|
131
134
|
)
|
|
132
|
-
|
|
133
|
-
tiledb.Array.create(uri, schema)
|
|
135
|
+
tiledb.Array.create(uri, schema, ctx=tiledb_ctx)
|
|
134
136
|
|
|
135
137
|
# Import here to avoid circular imports
|
|
136
|
-
from .
|
|
137
|
-
from .
|
|
138
|
+
from .dense import DenseCellArray
|
|
139
|
+
from .sparse import SparseCellArray
|
|
138
140
|
|
|
139
|
-
# Return appropriate array type
|
|
140
141
|
return (
|
|
141
|
-
SparseCellArray(uri=uri, attr=attr_name, mode=mode)
|
|
142
|
+
SparseCellArray(uri=uri, attr=attr_name, mode=mode, config_or_context=tiledb_ctx)
|
|
142
143
|
if sparse
|
|
143
|
-
else DenseCellArray(uri=uri, attr=attr_name, mode=mode)
|
|
144
|
+
else DenseCellArray(uri=uri, attr=attr_name, mode=mode, config_or_context=tiledb_ctx)
|
|
144
145
|
)
|
|
145
146
|
|
|
146
147
|
|
|
@@ -149,19 +150,27 @@ class SliceHelper:
|
|
|
149
150
|
|
|
150
151
|
@staticmethod
|
|
151
152
|
def is_contiguous_indices(indices: List[int]) -> Optional[slice]:
|
|
152
|
-
"""Check if indices can be represented as a contiguous slice."""
|
|
153
153
|
if not indices:
|
|
154
154
|
return None
|
|
155
155
|
|
|
156
|
-
|
|
156
|
+
sorted_indices = sorted(list(set(indices)))
|
|
157
|
+
if not sorted_indices:
|
|
158
|
+
return None
|
|
159
|
+
|
|
160
|
+
if len(sorted_indices) == 1:
|
|
161
|
+
return slice(sorted_indices[0], sorted_indices[0] + 1, None)
|
|
162
|
+
|
|
163
|
+
diffs = np.diff(sorted_indices)
|
|
157
164
|
if np.all(diffs == 1):
|
|
158
|
-
return slice(
|
|
165
|
+
return slice(sorted_indices[0], sorted_indices[-1] + 1, None)
|
|
166
|
+
|
|
159
167
|
return None
|
|
160
168
|
|
|
161
169
|
@staticmethod
|
|
162
|
-
def normalize_index(
|
|
170
|
+
def normalize_index(
|
|
171
|
+
idx: Union[int, range, slice, List[int], EllipsisType], dim_size: int
|
|
172
|
+
) -> Union[slice, List[int], EllipsisType]:
|
|
163
173
|
"""Normalize index to handle negative indices and ensure consistency."""
|
|
164
|
-
|
|
165
174
|
if isinstance(idx, EllipsisType):
|
|
166
175
|
return idx
|
|
167
176
|
|
|
@@ -170,36 +179,61 @@ class SliceHelper:
|
|
|
170
179
|
idx = slice(idx.start, idx.stop, idx.step)
|
|
171
180
|
|
|
172
181
|
if isinstance(idx, slice):
|
|
173
|
-
start = idx.start
|
|
174
|
-
stop = idx.stop
|
|
182
|
+
start = idx.start
|
|
183
|
+
stop = idx.stop
|
|
175
184
|
step = idx.step
|
|
176
185
|
|
|
186
|
+
# Resolve None to full dimension slice parts
|
|
187
|
+
if start is None:
|
|
188
|
+
start = 0
|
|
189
|
+
|
|
190
|
+
if stop is None:
|
|
191
|
+
stop = dim_size
|
|
192
|
+
|
|
177
193
|
# Handle negative indices
|
|
178
194
|
if start < 0:
|
|
179
|
-
start
|
|
180
|
-
|
|
195
|
+
start += dim_size
|
|
181
196
|
if stop < 0:
|
|
182
|
-
stop
|
|
197
|
+
stop += dim_size
|
|
183
198
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
if
|
|
187
|
-
|
|
199
|
+
# slice allows start > dim_size or stop < 0 to result in empty slices.
|
|
200
|
+
# Note: start == dim_size is OK for empty slice like arr[dim_size:]
|
|
201
|
+
if start < 0 or (start >= dim_size and dim_size > 0):
|
|
202
|
+
if not (start == dim_size and (step is None or step > 0)):
|
|
203
|
+
if start >= dim_size:
|
|
204
|
+
raise IndexError(
|
|
205
|
+
f"Start index {idx.start if idx.start is not None else 'None'} results in {start}, which is out of bounds for dimension size {dim_size}."
|
|
206
|
+
)
|
|
188
207
|
|
|
189
|
-
|
|
208
|
+
# Clamping slice arguments to dimensions
|
|
209
|
+
stop = min(stop, dim_size)
|
|
210
|
+
start = max(0, start)
|
|
190
211
|
|
|
212
|
+
return slice(start, stop, step)
|
|
191
213
|
elif isinstance(idx, list):
|
|
214
|
+
if not idx:
|
|
215
|
+
return []
|
|
216
|
+
|
|
192
217
|
norm_idx = [i if i >= 0 else dim_size + i for i in idx]
|
|
193
218
|
if any(i < 0 or i >= dim_size for i in norm_idx):
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
219
|
+
oob_indices = [orig_i for orig_i, norm_i in zip(idx, norm_idx) if not (0 <= norm_i < dim_size)]
|
|
220
|
+
raise IndexError(
|
|
221
|
+
f"List indices {oob_indices} (original values) are out of bounds for dimension size {dim_size}."
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# TileDB multi_index usually returns data sorted by coordinates
|
|
225
|
+
return sorted(list(set(norm_idx)))
|
|
226
|
+
elif isinstance(idx, (int, np.integer)):
|
|
227
|
+
norm_idx = int(idx)
|
|
228
|
+
if norm_idx < 0:
|
|
229
|
+
norm_idx += dim_size
|
|
230
|
+
|
|
231
|
+
if not (0 <= norm_idx < dim_size):
|
|
201
232
|
raise IndexError(f"Index {idx} out of bounds for dimension size {dim_size}")
|
|
233
|
+
|
|
202
234
|
return slice(norm_idx, norm_idx + 1, None)
|
|
235
|
+
else:
|
|
236
|
+
raise TypeError(f"Index type {type(idx)} not supported for normalization.")
|
|
203
237
|
|
|
204
238
|
|
|
205
239
|
def create_group(output_path, group_name):
|
|
@@ -9,8 +9,8 @@ import numpy as np
|
|
|
9
9
|
import tiledb
|
|
10
10
|
from scipy import sparse
|
|
11
11
|
|
|
12
|
-
from .cellarray_base import CellArray
|
|
13
12
|
from .helpers import SliceHelper
|
|
13
|
+
from .base import CellArray
|
|
14
14
|
|
|
15
15
|
__author__ = "Jayaram Kancherla"
|
|
16
16
|
__copyright__ = "Jayaram Kancherla"
|
|
@@ -28,7 +28,7 @@ class SparseCellArray(CellArray):
|
|
|
28
28
|
mode: Optional[Literal["r", "w", "d", "m"]] = None,
|
|
29
29
|
config_or_context: Optional[Union[tiledb.Config, tiledb.Ctx]] = None,
|
|
30
30
|
return_sparse: bool = True,
|
|
31
|
-
|
|
31
|
+
sparse_format: Union[sparse.csr_matrix, sparse.csc_matrix] = sparse.csr_matrix,
|
|
32
32
|
validate: bool = True,
|
|
33
33
|
**kwargs,
|
|
34
34
|
):
|
|
@@ -66,7 +66,7 @@ class SparseCellArray(CellArray):
|
|
|
66
66
|
Whether to return a sparse representation of the data when object is sliced.
|
|
67
67
|
Default is to return a dictionary that contains coordinates and values.
|
|
68
68
|
|
|
69
|
-
|
|
69
|
+
sparse_format:
|
|
70
70
|
Format to return, defaults to csr_matrix.
|
|
71
71
|
|
|
72
72
|
validate:
|
|
@@ -86,7 +86,7 @@ class SparseCellArray(CellArray):
|
|
|
86
86
|
)
|
|
87
87
|
|
|
88
88
|
self.return_sparse = return_sparse
|
|
89
|
-
self.
|
|
89
|
+
self.sparse_format = sparse.csr_matrix if sparse_format is None else sparse_format
|
|
90
90
|
|
|
91
91
|
def _validate_matrix_dims(self, data: sparse.spmatrix) -> Tuple[sparse.coo_matrix, bool]:
|
|
92
92
|
"""Validate and adjust matrix dimensions if needed.
|
|
@@ -126,7 +126,7 @@ class SparseCellArray(CellArray):
|
|
|
126
126
|
shape.append(idx.stop - (idx.start or 0))
|
|
127
127
|
elif isinstance(idx, list):
|
|
128
128
|
shape.append(len(set(idx)))
|
|
129
|
-
else:
|
|
129
|
+
else:
|
|
130
130
|
shape.append(1)
|
|
131
131
|
|
|
132
132
|
# Always return (n,1) shape for CSR matrix
|
|
@@ -140,20 +140,17 @@ class SparseCellArray(CellArray):
|
|
|
140
140
|
"""Convert TileDB result to CSR format or dense array."""
|
|
141
141
|
data = result[self._attr]
|
|
142
142
|
|
|
143
|
-
# empty result
|
|
144
143
|
if len(data) == 0:
|
|
145
144
|
print("is emoty")
|
|
146
145
|
if not self.return_sparse:
|
|
147
146
|
return result
|
|
148
147
|
else:
|
|
149
|
-
# For COO output, return empty sparse matrix
|
|
150
148
|
if self.ndim == 1:
|
|
151
|
-
matrix = self.
|
|
149
|
+
matrix = self.sparse_format((1, shape[0]))
|
|
152
150
|
return matrix[:, key[0]]
|
|
153
151
|
|
|
154
|
-
return self.
|
|
152
|
+
return self.sparse_format(shape)[key]
|
|
155
153
|
|
|
156
|
-
# Get coordinates
|
|
157
154
|
coords = []
|
|
158
155
|
for dim_name in self.dim_names:
|
|
159
156
|
dim_coords = result[dim_name]
|
|
@@ -164,11 +161,12 @@ class SparseCellArray(CellArray):
|
|
|
164
161
|
coords = [np.zeros_like(coords[0]), coords[0]]
|
|
165
162
|
shape = (1, shape[0])
|
|
166
163
|
|
|
167
|
-
# Create sparse matrix
|
|
168
164
|
matrix = sparse.coo_matrix((data, tuple(coords)), shape=shape)
|
|
169
|
-
|
|
165
|
+
|
|
166
|
+
sliced = matrix
|
|
167
|
+
if self.sparse_format in (sparse.csr_matrix, sparse.csr_array):
|
|
170
168
|
sliced = matrix.tocsr()
|
|
171
|
-
elif self.
|
|
169
|
+
elif self.sparse_format in (sparse.csc_matrix, sparse.csc_array):
|
|
172
170
|
sliced = matrix.tocsc()
|
|
173
171
|
|
|
174
172
|
if self.ndim == 1:
|
|
@@ -200,7 +198,6 @@ class SparseCellArray(CellArray):
|
|
|
200
198
|
if all(isinstance(idx, slice) for idx in optimized_key):
|
|
201
199
|
return self._direct_slice(tuple(optimized_key))
|
|
202
200
|
|
|
203
|
-
# For mixed slice-list queries, adjust slice bounds
|
|
204
201
|
tiledb_key = []
|
|
205
202
|
for idx in key:
|
|
206
203
|
if isinstance(idx, slice):
|
|
@@ -239,10 +236,8 @@ class SparseCellArray(CellArray):
|
|
|
239
236
|
if not sparse.issparse(data):
|
|
240
237
|
raise TypeError("Input must be a scipy sparse matrix.")
|
|
241
238
|
|
|
242
|
-
# Validate and adjust dimensions
|
|
243
239
|
coo_data, is_1d = self._validate_matrix_dims(data)
|
|
244
240
|
|
|
245
|
-
# Check bounds
|
|
246
241
|
end_row = start_row + coo_data.shape[0]
|
|
247
242
|
if end_row > self.shape[0]:
|
|
248
243
|
raise ValueError(
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from warnings import warn
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import tiledb
|
|
6
|
+
import torch
|
|
7
|
+
from torch.utils.data import DataLoader, Dataset
|
|
8
|
+
|
|
9
|
+
from ..core.dense import DenseCellArray
|
|
10
|
+
|
|
11
|
+
__author__ = "Jayaram Kancherla"
|
|
12
|
+
__copyright__ = "Jayaram Kancherla"
|
|
13
|
+
__license__ = "MIT"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DenseArrayDataset(Dataset):
|
|
17
|
+
def __init__(
|
|
18
|
+
self,
|
|
19
|
+
array_uri: str,
|
|
20
|
+
attribute_name: str = "data",
|
|
21
|
+
num_rows: Optional[int] = None,
|
|
22
|
+
num_columns: Optional[int] = None,
|
|
23
|
+
cellarr_ctx_config: Optional[dict] = None,
|
|
24
|
+
transform=None,
|
|
25
|
+
):
|
|
26
|
+
"""PyTorch Dataset for dense TileDB arrays accessed via DenseCellArray.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
array_uri:
|
|
30
|
+
URI of the TileDB dense array.
|
|
31
|
+
|
|
32
|
+
attribute_name:
|
|
33
|
+
Name of the attribute to read from.
|
|
34
|
+
|
|
35
|
+
num_rows:
|
|
36
|
+
Total number of rows in the dataset.
|
|
37
|
+
If None, will infer from `array.shape[0]`.
|
|
38
|
+
|
|
39
|
+
num_columns:
|
|
40
|
+
The number of columns in the dataset.
|
|
41
|
+
If None, will attempt to infer `from array.shape[1]`.
|
|
42
|
+
|
|
43
|
+
cellarr_ctx_config:
|
|
44
|
+
Optional TileDB context configuration dict for CellArray.
|
|
45
|
+
|
|
46
|
+
transform:
|
|
47
|
+
Optional transform to be applied on a sample.
|
|
48
|
+
"""
|
|
49
|
+
self.array_uri = array_uri
|
|
50
|
+
self.attribute_name = attribute_name
|
|
51
|
+
self.cellarr_ctx_config = cellarr_ctx_config
|
|
52
|
+
self.transform = transform
|
|
53
|
+
self.cell_array_instance = None
|
|
54
|
+
|
|
55
|
+
if num_rows is not None and num_columns is not None:
|
|
56
|
+
self._len = num_rows
|
|
57
|
+
self.num_columns = num_columns
|
|
58
|
+
else:
|
|
59
|
+
# Infer the array shape
|
|
60
|
+
print(f"Dataset '{array_uri}': num_rows or num_columns not provided. Probing array...")
|
|
61
|
+
init_ctx_config = tiledb.Config(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
|
|
62
|
+
try:
|
|
63
|
+
temp_arr = DenseCellArray(
|
|
64
|
+
uri=self.array_uri, attr=self.attribute_name, config_or_context=init_ctx_config
|
|
65
|
+
)
|
|
66
|
+
if temp_arr.ndim == 1:
|
|
67
|
+
self._len = num_rows if num_rows is not None else temp_arr.shape[0]
|
|
68
|
+
self.num_columns = 1
|
|
69
|
+
elif temp_arr.ndim == 2:
|
|
70
|
+
self._len = num_rows if num_rows is not None else temp_arr.shape[0]
|
|
71
|
+
self.num_columns = num_columns if num_columns is not None else temp_arr.shape[1]
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError(f"Array ndim {temp_arr.ndim} not supported.")
|
|
74
|
+
|
|
75
|
+
print(f"Dataset '{array_uri}': Inferred shape. Rows: {self._len}, Columns: {self.num_columns}")
|
|
76
|
+
|
|
77
|
+
except Exception as e:
|
|
78
|
+
if num_rows is None or num_columns is None:
|
|
79
|
+
raise ValueError(
|
|
80
|
+
f"num_rows and num_columns must be provided if inferring array shape fails for '{array_uri}'."
|
|
81
|
+
) from e
|
|
82
|
+
self._len = num_rows
|
|
83
|
+
self.feature_dim = num_columns
|
|
84
|
+
warn(
|
|
85
|
+
f"Falling back to provided or zero dimensions for '{array_uri}' due to inference error: {e}",
|
|
86
|
+
RuntimeWarning,
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
if self.num_columns is None or self.num_columns <= 0 and self._len > 0: # Check if num_columns is valid
|
|
90
|
+
raise ValueError(
|
|
91
|
+
f"num_columns ({self.num_columns}) is invalid or could not be determined for array '{array_uri}'."
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
if self._len == 0:
|
|
95
|
+
warn(f"Dataset for '{array_uri}' has length 0.", RuntimeWarning)
|
|
96
|
+
|
|
97
|
+
def _init_worker_state(self):
|
|
98
|
+
"""Initializes the DenseCellArray instance for the current worker."""
|
|
99
|
+
if self.cell_array_instance is None:
|
|
100
|
+
ctx = tiledb.Ctx(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
|
|
101
|
+
self.cell_array_instance = DenseCellArray(
|
|
102
|
+
uri=self.array_uri, attr=self.attribute_name, mode="r", config_or_context=ctx
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Sanity check: worker's shape against dataset's established dims
|
|
106
|
+
# if self.cell_array_instance.shape[0] != self._len or \
|
|
107
|
+
# (self.cell_array_instance.ndim > 1 and self.cell_array_instance.shape[1] != self.feature_dim) or \
|
|
108
|
+
# (self.cell_array_instance.ndim == 1 and self.feature_dim != 1) :
|
|
109
|
+
# print(f"Warning: Worker for {self.array_uri} sees shape {self.cell_array_instance.shape} "
|
|
110
|
+
# f"but dataset initialized with len={self._len}, feat={self.feature_dim}")
|
|
111
|
+
|
|
112
|
+
def __len__(self):
|
|
113
|
+
return self._len
|
|
114
|
+
|
|
115
|
+
def __getitem__(self, idx):
|
|
116
|
+
if not 0 <= idx < self._len:
|
|
117
|
+
raise IndexError(f"Index {idx} out of bounds for dataset of length {self._len}.")
|
|
118
|
+
|
|
119
|
+
self._init_worker_state()
|
|
120
|
+
|
|
121
|
+
if self.cell_array_instance.ndim == 2:
|
|
122
|
+
item_slice = (slice(idx, idx + 1), slice(None))
|
|
123
|
+
elif self.cell_array_instance.ndim == 1:
|
|
124
|
+
item_slice = slice(idx, idx + 1)
|
|
125
|
+
else:
|
|
126
|
+
raise ValueError(f"Array ndim {self.cell_array_instance.ndim} not supported in __getitem__.")
|
|
127
|
+
|
|
128
|
+
sample_data_np = self.cell_array_instance[item_slice]
|
|
129
|
+
if sample_data_np.ndim == 2 and sample_data_np.shape[0] == 1:
|
|
130
|
+
sample_data_np = sample_data_np.squeeze(0)
|
|
131
|
+
elif sample_data_np.ndim == 1 and sample_data_np.shape[0] == 1 and self.feature_dim == 1:
|
|
132
|
+
pass
|
|
133
|
+
elif sample_data_np.ndim == 0 and self.feature_dim == 1:
|
|
134
|
+
sample_data_np = np.array([sample_data_np])
|
|
135
|
+
|
|
136
|
+
if self.transform:
|
|
137
|
+
sample_data_np = self.transform(sample_data_np)
|
|
138
|
+
|
|
139
|
+
return torch.from_numpy(sample_data_np)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def construct_dense_array_dataloader(
|
|
143
|
+
array_uri: str,
|
|
144
|
+
attribute_name: str = "data",
|
|
145
|
+
num_rows: Optional[int] = None,
|
|
146
|
+
num_columns: Optional[int] = None,
|
|
147
|
+
batch_size: int = 1000,
|
|
148
|
+
num_workers_dl: int = 2,
|
|
149
|
+
) -> DataLoader:
|
|
150
|
+
"""Construct an instance of `DenseArrayDataset` with PyTorch DataLoader.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
array_uri:
|
|
154
|
+
URI of the TileDB array.
|
|
155
|
+
|
|
156
|
+
attribute_name:
|
|
157
|
+
Name of the attribute to read from.
|
|
158
|
+
|
|
159
|
+
num_rows:
|
|
160
|
+
The total number of rows in the TileDB array.
|
|
161
|
+
|
|
162
|
+
num_columns:
|
|
163
|
+
The total number of columns in the TileDB array.
|
|
164
|
+
|
|
165
|
+
batch_size:
|
|
166
|
+
Number of random samples per batch generated by the dataset.
|
|
167
|
+
|
|
168
|
+
num_workers_dl:
|
|
169
|
+
Number of worker processes for the DataLoader.
|
|
170
|
+
"""
|
|
171
|
+
tiledb_ctx_config = {
|
|
172
|
+
"sm.tile_cache_size": 1000 * 1024**2, # 1000 MB tile cache per worker
|
|
173
|
+
"sm.num_reader_threads": 4,
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
dataset = DenseArrayDataset(
|
|
177
|
+
array_uri=array_uri,
|
|
178
|
+
attribute_name=attribute_name,
|
|
179
|
+
num_rows=num_rows,
|
|
180
|
+
num_columns=num_columns,
|
|
181
|
+
cellarr_ctx_config=tiledb_ctx_config,
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
if len(dataset) == 0:
|
|
185
|
+
print("Dataset is empty, cannot create DataLoader.")
|
|
186
|
+
return
|
|
187
|
+
|
|
188
|
+
dataloader = DataLoader(
|
|
189
|
+
dataset,
|
|
190
|
+
batch_size=batch_size,
|
|
191
|
+
shuffle=True,
|
|
192
|
+
num_workers=num_workers_dl,
|
|
193
|
+
pin_memory=True,
|
|
194
|
+
prefetch_factor=2,
|
|
195
|
+
persistent_workers=True if num_workers_dl > 0 else False,
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
return dataloader
|