cellarr-array 0.0.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cellarr-array might be problematic. Click here for more details.
- cellarr_array/__init__.py +2 -4
- cellarr_array/core/__init__.py +3 -0
- cellarr_array/core/base.py +344 -0
- cellarr_array/{DenseCellArray.py → core/dense.py} +2 -3
- cellarr_array/{helpers.py → core/helpers.py} +80 -42
- cellarr_array/{SparseCellArray.py → core/sparse.py} +75 -27
- cellarr_array/dataloaders/__init__.py +3 -0
- cellarr_array/dataloaders/denseloader.py +198 -0
- cellarr_array/dataloaders/iterabledataloader.py +320 -0
- cellarr_array/dataloaders/sparseloader.py +230 -0
- cellarr_array/dataloaders/utils.py +26 -0
- cellarr_array/utils/__init__.py +3 -0
- cellarr_array/utils/mock.py +167 -0
- {cellarr_array-0.0.3.dist-info → cellarr_array-0.2.0.dist-info}/METADATA +4 -1
- cellarr_array-0.2.0.dist-info/RECORD +19 -0
- {cellarr_array-0.0.3.dist-info → cellarr_array-0.2.0.dist-info}/WHEEL +1 -1
- {cellarr_array-0.0.3.dist-info → cellarr_array-0.2.0.dist-info}/licenses/LICENSE.txt +1 -1
- cellarr_array/CellArray.py +0 -251
- cellarr_array-0.0.3.dist-info/RECORD +0 -11
- /cellarr_array/{config.py → utils/config.py} +0 -0
- {cellarr_array-0.0.3.dist-info → cellarr_array-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -3,14 +3,14 @@ try:
|
|
|
3
3
|
except ImportError:
|
|
4
4
|
# TODO: This is required for Python <3.10. Remove once Python 3.9 reaches EOL in October 2025
|
|
5
5
|
EllipsisType = type(...)
|
|
6
|
-
from typing import Dict, List, Optional, Tuple, Union
|
|
6
|
+
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tiledb
|
|
10
10
|
from scipy import sparse
|
|
11
11
|
|
|
12
|
-
from .CellArray import CellArray
|
|
13
12
|
from .helpers import SliceHelper
|
|
13
|
+
from .base import CellArray
|
|
14
14
|
|
|
15
15
|
__author__ = "Jayaram Kancherla"
|
|
16
16
|
__copyright__ = "Jayaram Kancherla"
|
|
@@ -22,18 +22,71 @@ class SparseCellArray(CellArray):
|
|
|
22
22
|
|
|
23
23
|
def __init__(
|
|
24
24
|
self,
|
|
25
|
-
uri: str,
|
|
25
|
+
uri: Optional[str] = None,
|
|
26
|
+
tiledb_array_obj: Optional[tiledb.Array] = None,
|
|
26
27
|
attr: str = "data",
|
|
27
|
-
mode:
|
|
28
|
+
mode: Optional[Literal["r", "w", "d", "m"]] = None,
|
|
28
29
|
config_or_context: Optional[Union[tiledb.Config, tiledb.Ctx]] = None,
|
|
29
30
|
return_sparse: bool = True,
|
|
30
|
-
|
|
31
|
+
sparse_format: Union[sparse.csr_matrix, sparse.csc_matrix] = sparse.csr_matrix,
|
|
32
|
+
validate: bool = True,
|
|
33
|
+
**kwargs,
|
|
31
34
|
):
|
|
32
|
-
"""Initialize
|
|
33
|
-
|
|
35
|
+
"""Initialize the object.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
uri:
|
|
39
|
+
URI to the array.
|
|
40
|
+
Required if 'tiledb_array_obj' is not provided.
|
|
41
|
+
|
|
42
|
+
tiledb_array_obj:
|
|
43
|
+
Optional, an already opened ``tiledb.Array`` instance.
|
|
44
|
+
If provided, 'uri' can be None, and 'config_or_context' is ignored.
|
|
45
|
+
|
|
46
|
+
attr:
|
|
47
|
+
Attribute to access.
|
|
48
|
+
Defaults to "data".
|
|
49
|
+
|
|
50
|
+
mode:
|
|
51
|
+
Open the array object in read 'r', write 'w', modify
|
|
52
|
+
'm' mode, or delete 'd' mode.
|
|
53
|
+
|
|
54
|
+
Defaults to None for automatic mode switching.
|
|
55
|
+
|
|
56
|
+
If 'tiledb_array_obj' is provided, this mode should ideally match
|
|
57
|
+
the mode of the provided array or be None.
|
|
58
|
+
|
|
59
|
+
config_or_context:
|
|
60
|
+
Optional config or context object. Ignored if 'tiledb_array_obj' is provided,
|
|
61
|
+
as context will be derived from the object.
|
|
62
|
+
|
|
63
|
+
Defaults to None.
|
|
64
|
+
|
|
65
|
+
return_sparse:
|
|
66
|
+
Whether to return a sparse representation of the data when object is sliced.
|
|
67
|
+
Default is to return a dictionary that contains coordinates and values.
|
|
68
|
+
|
|
69
|
+
sparse_format:
|
|
70
|
+
Format to return, defaults to csr_matrix.
|
|
71
|
+
|
|
72
|
+
validate:
|
|
73
|
+
Whether to validate the attributes.
|
|
74
|
+
Defaults to True.
|
|
75
|
+
|
|
76
|
+
kwargs:
|
|
77
|
+
Additional arguments.
|
|
78
|
+
"""
|
|
79
|
+
super().__init__(
|
|
80
|
+
uri=uri,
|
|
81
|
+
tiledb_array_obj=tiledb_array_obj,
|
|
82
|
+
attr=attr,
|
|
83
|
+
mode=mode,
|
|
84
|
+
config_or_context=config_or_context,
|
|
85
|
+
validate=validate,
|
|
86
|
+
)
|
|
34
87
|
|
|
35
88
|
self.return_sparse = return_sparse
|
|
36
|
-
self.
|
|
89
|
+
self.sparse_format = sparse.csr_matrix if sparse_format is None else sparse_format
|
|
37
90
|
|
|
38
91
|
def _validate_matrix_dims(self, data: sparse.spmatrix) -> Tuple[sparse.coo_matrix, bool]:
|
|
39
92
|
"""Validate and adjust matrix dimensions if needed.
|
|
@@ -73,7 +126,7 @@ class SparseCellArray(CellArray):
|
|
|
73
126
|
shape.append(idx.stop - (idx.start or 0))
|
|
74
127
|
elif isinstance(idx, list):
|
|
75
128
|
shape.append(len(set(idx)))
|
|
76
|
-
else:
|
|
129
|
+
else:
|
|
77
130
|
shape.append(1)
|
|
78
131
|
|
|
79
132
|
# Always return (n,1) shape for CSR matrix
|
|
@@ -87,20 +140,17 @@ class SparseCellArray(CellArray):
|
|
|
87
140
|
"""Convert TileDB result to CSR format or dense array."""
|
|
88
141
|
data = result[self._attr]
|
|
89
142
|
|
|
90
|
-
# empty result
|
|
91
143
|
if len(data) == 0:
|
|
92
144
|
print("is emoty")
|
|
93
145
|
if not self.return_sparse:
|
|
94
146
|
return result
|
|
95
147
|
else:
|
|
96
|
-
# For COO output, return empty sparse matrix
|
|
97
148
|
if self.ndim == 1:
|
|
98
|
-
matrix = self.
|
|
149
|
+
matrix = self.sparse_format((1, shape[0]))
|
|
99
150
|
return matrix[:, key[0]]
|
|
100
151
|
|
|
101
|
-
return self.
|
|
152
|
+
return self.sparse_format(shape)[key]
|
|
102
153
|
|
|
103
|
-
# Get coordinates
|
|
104
154
|
coords = []
|
|
105
155
|
for dim_name in self.dim_names:
|
|
106
156
|
dim_coords = result[dim_name]
|
|
@@ -111,11 +161,12 @@ class SparseCellArray(CellArray):
|
|
|
111
161
|
coords = [np.zeros_like(coords[0]), coords[0]]
|
|
112
162
|
shape = (1, shape[0])
|
|
113
163
|
|
|
114
|
-
# Create sparse matrix
|
|
115
164
|
matrix = sparse.coo_matrix((data, tuple(coords)), shape=shape)
|
|
116
|
-
|
|
165
|
+
|
|
166
|
+
sliced = matrix
|
|
167
|
+
if self.sparse_format in (sparse.csr_matrix, sparse.csr_array):
|
|
117
168
|
sliced = matrix.tocsr()
|
|
118
|
-
elif self.
|
|
169
|
+
elif self.sparse_format in (sparse.csc_matrix, sparse.csc_array):
|
|
119
170
|
sliced = matrix.tocsc()
|
|
120
171
|
|
|
121
172
|
if self.ndim == 1:
|
|
@@ -147,7 +198,6 @@ class SparseCellArray(CellArray):
|
|
|
147
198
|
if all(isinstance(idx, slice) for idx in optimized_key):
|
|
148
199
|
return self._direct_slice(tuple(optimized_key))
|
|
149
200
|
|
|
150
|
-
# For mixed slice-list queries, adjust slice bounds
|
|
151
201
|
tiledb_key = []
|
|
152
202
|
for idx in key:
|
|
153
203
|
if isinstance(idx, slice):
|
|
@@ -186,22 +236,20 @@ class SparseCellArray(CellArray):
|
|
|
186
236
|
if not sparse.issparse(data):
|
|
187
237
|
raise TypeError("Input must be a scipy sparse matrix.")
|
|
188
238
|
|
|
189
|
-
|
|
190
|
-
data, is_1d = self._validate_matrix_dims(data)
|
|
239
|
+
coo_data, is_1d = self._validate_matrix_dims(data)
|
|
191
240
|
|
|
192
|
-
|
|
193
|
-
end_row = start_row + data.shape[0]
|
|
241
|
+
end_row = start_row + coo_data.shape[0]
|
|
194
242
|
if end_row > self.shape[0]:
|
|
195
243
|
raise ValueError(
|
|
196
244
|
f"Write operation would exceed array bounds. End row {end_row} > array rows {self.shape[0]}."
|
|
197
245
|
)
|
|
198
246
|
|
|
199
|
-
if not is_1d and
|
|
200
|
-
raise ValueError(f"Data columns {
|
|
247
|
+
if not is_1d and coo_data.shape[1] != self.shape[1]:
|
|
248
|
+
raise ValueError(f"Data columns {coo_data.shape[1]} don't match array columns {self.shape[1]}.")
|
|
201
249
|
|
|
202
|
-
adjusted_rows =
|
|
250
|
+
adjusted_rows = coo_data.row + start_row
|
|
203
251
|
with self.open_array(mode="w") as array:
|
|
204
252
|
if is_1d:
|
|
205
|
-
array[adjusted_rows] =
|
|
253
|
+
array[adjusted_rows] = coo_data.data
|
|
206
254
|
else:
|
|
207
|
-
array[adjusted_rows,
|
|
255
|
+
array[adjusted_rows, coo_data.col] = coo_data.data
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from warnings import warn
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import tiledb
|
|
6
|
+
import torch
|
|
7
|
+
from torch.utils.data import DataLoader, Dataset
|
|
8
|
+
|
|
9
|
+
from ..core.dense import DenseCellArray
|
|
10
|
+
|
|
11
|
+
__author__ = "Jayaram Kancherla"
|
|
12
|
+
__copyright__ = "Jayaram Kancherla"
|
|
13
|
+
__license__ = "MIT"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DenseArrayDataset(Dataset):
    """Map-style PyTorch ``Dataset`` yielding one row of a dense TileDB array per index.

    The ``DenseCellArray`` used for reading is not created in ``__init__``; it is opened
    lazily on first item access (see ``_init_worker_state``).
    """

    def __init__(
        self,
        array_uri: str,
        attribute_name: str = "data",
        num_rows: Optional[int] = None,
        num_columns: Optional[int] = None,
        cellarr_ctx_config: Optional[dict] = None,
        transform=None,
    ):
        """PyTorch Dataset for dense TileDB arrays accessed via DenseCellArray.

        Args:
            array_uri:
                URI of the TileDB dense array.

            attribute_name:
                Name of the attribute to read from.

            num_rows:
                Total number of rows in the dataset.
                If None, will infer from ``array.shape[0]``.

            num_columns:
                The number of columns in the dataset.
                If None, will infer from ``array.shape[1]`` (1 for a 1-D array).

            cellarr_ctx_config:
                Optional TileDB context configuration dict for CellArray.

            transform:
                Optional transform applied to each NumPy sample before it is
                converted to a tensor; it must return a NumPy array.

        Raises:
            ValueError:
                If the shape has to be inferred and the array cannot be opened,
                if the array is neither 1-D nor 2-D, or if the resulting column
                count is invalid.
        """
        self.array_uri = array_uri
        self.attribute_name = attribute_name
        self.cellarr_ctx_config = cellarr_ctx_config
        self.transform = transform
        # Opened lazily on first access by ``_init_worker_state``.
        self.cell_array_instance = None

        if num_rows is not None and num_columns is not None:
            self._len = num_rows
            self.num_columns = num_columns
        else:
            self._len, self.num_columns = self._infer_shape(num_rows, num_columns)

        # A missing column count is always invalid; a non-positive one only for a non-empty dataset.
        if self.num_columns is None or (self.num_columns <= 0 and self._len > 0):
            raise ValueError(
                f"num_columns ({self.num_columns}) is invalid or could not be determined for array '{array_uri}'."
            )

        if self._len == 0:
            warn(f"Dataset for '{array_uri}' has length 0.", RuntimeWarning)

    def _infer_shape(self, num_rows: Optional[int], num_columns: Optional[int]) -> tuple:
        """Open the array once to fill in whichever dimension was not provided.

        Explicit values take precedence over the probed shape, except that a 1-D
        array always reports a single column.

        Returns:
            A ``(num_rows, num_columns)`` tuple.

        Raises:
            ValueError: If the array cannot be opened or is not 1-D or 2-D.
        """
        print(f"Dataset '{self.array_uri}': num_rows or num_columns not provided. Probing array...")
        init_ctx_config = tiledb.Config(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
        try:
            temp_arr = DenseCellArray(
                uri=self.array_uri, attr=self.attribute_name, config_or_context=init_ctx_config
            )
            ndim, shape = temp_arr.ndim, temp_arr.shape
        except Exception as e:
            # At least one dimension is missing here, so there is nothing to fall back to.
            raise ValueError(
                f"num_rows and num_columns must be provided if inferring array shape fails for '{self.array_uri}'."
            ) from e

        if ndim == 1:
            inferred_columns = 1
        elif ndim == 2:
            inferred_columns = num_columns if num_columns is not None else shape[1]
        else:
            raise ValueError(f"Array ndim {ndim} not supported.")
        inferred_rows = num_rows if num_rows is not None else shape[0]

        print(f"Dataset '{self.array_uri}': Inferred shape. Rows: {inferred_rows}, Columns: {inferred_columns}")
        return inferred_rows, inferred_columns

    def _init_worker_state(self):
        """Open the read-mode ``DenseCellArray`` for the current process on first use.

        Because ``__init__`` never creates the instance, copies of the dataset sent to
        DataLoader worker processes before any access each open their own handle.
        """
        if self.cell_array_instance is None:
            ctx = tiledb.Ctx(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
            self.cell_array_instance = DenseCellArray(
                uri=self.array_uri, attr=self.attribute_name, mode="r", config_or_context=ctx
            )

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """Return row ``idx`` as a tensor.

        For a 2-D array the leading singleton axis of the one-row slice is squeezed,
        giving a 1-D row. For a 1-D array the one-element slice is returned as-is; a
        0-d result is wrapped to shape ``(1,)`` when ``num_columns == 1``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
            ValueError: If the underlying array is neither 1-D nor 2-D.
        """
        if not 0 <= idx < self._len:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {self._len}.")

        self._init_worker_state()

        ndim = self.cell_array_instance.ndim
        if ndim == 2:
            item_slice = (slice(idx, idx + 1), slice(None))
        elif ndim == 1:
            item_slice = slice(idx, idx + 1)
        else:
            raise ValueError(f"Array ndim {ndim} not supported in __getitem__.")

        sample = self.cell_array_instance[item_slice]
        if sample.ndim == 2 and sample.shape[0] == 1:
            sample = sample.squeeze(0)
        elif sample.ndim == 0 and self.num_columns == 1:
            sample = np.array([sample])

        if self.transform:
            sample = self.transform(sample)

        return torch.from_numpy(sample)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def construct_dense_array_dataloader(
    array_uri: str,
    attribute_name: str = "data",
    num_rows: Optional[int] = None,
    num_columns: Optional[int] = None,
    batch_size: int = 1000,
    num_workers_dl: int = 2,
    cellarr_ctx_config: Optional[dict] = None,
) -> Optional[DataLoader]:
    """Construct an instance of `DenseArrayDataset` with PyTorch DataLoader.

    Args:
        array_uri:
            URI of the TileDB array.

        attribute_name:
            Name of the attribute to read from.

        num_rows:
            The total number of rows in the TileDB array.
            If None, inferred by probing the array.

        num_columns:
            The total number of columns in the TileDB array.
            If None, inferred by probing the array.

        batch_size:
            Number of rows per batch; rows are visited in shuffled order.

        num_workers_dl:
            Number of worker processes for the DataLoader.
            0 loads data in the main process.

        cellarr_ctx_config:
            Optional TileDB context configuration dict passed to the dataset.
            Defaults to a 1000 MB tile cache and 4 reader threads.

    Returns:
        A shuffling ``DataLoader`` over the dataset, or None if the dataset is empty.
    """
    if cellarr_ctx_config is None:
        cellarr_ctx_config = {
            "sm.tile_cache_size": 1000 * 1024**2,  # 1000 MB tile cache per worker
            "sm.num_reader_threads": 4,
        }

    dataset = DenseArrayDataset(
        array_uri=array_uri,
        attribute_name=attribute_name,
        num_rows=num_rows,
        num_columns=num_columns,
        cellarr_ctx_config=cellarr_ctx_config,
    )

    if len(dataset) == 0:
        print("Dataset is empty, cannot create DataLoader.")
        return None

    # prefetch_factor and persistent_workers only apply to worker processes;
    # DataLoader rejects a prefetch_factor when num_workers == 0.
    worker_kwargs = {}
    if num_workers_dl > 0:
        worker_kwargs = {"prefetch_factor": 2, "persistent_workers": True}

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers_dl,
        pin_memory=True,
        **worker_kwargs,
    )
|