cellarr-array 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cellarr-array might be problematic. Click here for more details.

@@ -0,0 +1,320 @@
1
+ from typing import Callable, Dict, Iterator, Optional, Union
2
+
3
+ import numpy as np
4
+ import scipy.sparse as sp
5
+ import tiledb
6
+ import torch
7
+ from torch.utils.data import DataLoader, IterableDataset
8
+
9
+ from .utils import seed_worker
10
+
11
+ __author__ = "Jayaram Kancherla"
12
+ __copyright__ = "Jayaram Kancherla"
13
+ __license__ = "MIT"
14
+
15
+
16
class CellArrayIterableDataset(IterableDataset):
    """A PyTorch IterableDataset that yields batches of randomly sampled rows
    from a TileDB array (dense or sparse) using cellarr-array.

    An `IterableDataset` is responsible for yielding entire batches of data,
    giving us full control over how a batch is formed, including performing
    a single bulk read from TileDB per batch.

    Rows within one batch are sampled without replacement; each batch is drawn
    independently, so a row may appear in several batches of the same epoch.
    """

    def __init__(
        self,
        array_uri: str,
        attribute_name: str,
        num_rows: int,
        num_columns: int,
        is_sparse: bool,
        batch_size: int = 1000,
        num_yields_per_epoch_per_worker: Optional[int] = None,
        cellarr_ctx_config: Optional[Dict] = None,
        transform: Optional[Callable] = None,
    ):
        """Initializes the `CellArrayIterableDataset`.

        Args:
            array_uri:
                URI of the TileDB array.

            attribute_name:
                Name of the TileDB attribute to read.

            num_rows:
                The total number of rows in the TileDB array.

            num_columns:
                The total number of columns in the TileDB array.

            is_sparse:
                True if the TileDB array is sparse, False if dense.

            batch_size:
                The number of random samples to include in each yielded batch.
                Defaults to 1000. Capped at `num_rows` when the array is smaller.

            num_yields_per_epoch_per_worker:
                The number of batches this dataset's iterator (per worker) will yield in one epoch.
                If None, defaults to ``num_rows // batch_size`` (at least 1 when
                ``num_rows > 0``, and 0 for an empty array). This count applies to
                *each* worker, so with ``W`` workers one epoch draws roughly ``W``
                passes' worth of rows, not one.
                The total batches seen by the training loop will be
                `num_workers * num_yields_per_epoch_per_worker`.
                Defaults to None.

            cellarr_ctx_config:
                Configuration dictionary for the TileDB context used by cellarr-array. Defaults to None.

            transform:
                A function/transform that takes the entire fetched batch (NumPy array for dense,
                SciPy sparse matrix for sparse) and returns a transformed version. Defaults to None.

        Raises:
            ValueError:
                If `batch_size` or `num_columns` is not a positive int, `num_rows` is not a
                non-negative int, or `num_yields_per_epoch_per_worker` is negative.
        """
        super().__init__()

        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")

        if not isinstance(num_rows, int) or num_rows < 0:
            raise ValueError("num_rows must be a non-negative integer.")

        if not isinstance(num_columns, int) or num_columns <= 0:
            raise ValueError("num_columns must be a positive integer.")

        # A negative count would silently yield nothing via range(); reject it instead.
        if num_yields_per_epoch_per_worker is not None and num_yields_per_epoch_per_worker < 0:
            raise ValueError("num_yields_per_epoch_per_worker must be a non-negative integer or None.")

        self.array_uri = array_uri
        self.attribute_name = attribute_name
        self.num_rows = num_rows
        self.num_columns = num_columns
        self.is_sparse = is_sparse
        self.batch_size = batch_size
        self.cellarr_ctx_config = cellarr_ctx_config
        self.transform = transform

        if num_yields_per_epoch_per_worker is None:
            # Roughly one pass over the rows for each worker; at least one batch if any rows exist.
            num_yields_per_epoch_per_worker = max(1, num_rows // batch_size) if num_rows > 0 else 0
        self.num_yields_per_epoch_per_worker = num_yields_per_epoch_per_worker

        # Opened lazily by `_init_worker_state` in whichever process iterates.
        self.cell_array_instance = None

    def _empty_batch(self) -> Union[np.ndarray, sp.spmatrix]:
        """Return a zero-row float32 batch with `self.num_columns` columns (COO if sparse)."""
        if self.is_sparse:
            return sp.coo_matrix((0, self.num_columns), dtype=np.float32)
        return np.empty((0, self.num_columns), dtype=np.float32)

    def _init_worker_state(self) -> None:
        """Lazily create the cellarr-array reader (DenseCellArray or SparseCellArray)
        for the current process.

        It is called at the start of `__iter__`, so each DataLoader worker opens
        its own TileDB array object instead of sharing one across processes.
        """
        if self.cell_array_instance is None:
            worker_info = torch.utils.data.get_worker_info()
            worker_id = worker_info.id if worker_info else "Main"
            print(f"Worker {worker_id}: Initializing CellArray instance.")

            ctx = tiledb.Ctx(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
            if self.is_sparse:
                from cellarr_array import SparseCellArray

                self.cell_array_instance = SparseCellArray(
                    uri=self.array_uri,
                    attr=self.attribute_name,
                    mode="r",
                    config_or_context=ctx,
                    return_sparse=True,
                    sparse_coerce=sp.coo_matrix,
                )
            else:
                from cellarr_array import DenseCellArray

                self.cell_array_instance = DenseCellArray(
                    uri=self.array_uri, attr=self.attribute_name, mode="r", config_or_context=ctx
                )

    def _fetch_one_random_batch(self) -> Union[np.ndarray, sp.spmatrix]:
        """Randomly select up to `self.batch_size` distinct row indices and fetch
        them from the TileDB array in a single multi-index read.

        Sampling uses NumPy's global RNG. The DataLoader built by
        `construct_iterable_dataloader` reseeds it per worker via `seed_worker`.

        Returns:
            A NumPy array (for dense) or SciPy sparse matrix (for sparse) with
            ``min(self.batch_size, self.num_rows)`` rows, or an empty
            ``(0, self.num_columns)`` batch when the array has no rows.
        """
        if self.num_rows == 0:
            return self._empty_batch()

        # Sampling without replacement needs size <= population; both are positive here.
        actual_batch_size = min(self.batch_size, self.num_rows)
        random_indices = np.random.choice(self.num_rows, size=actual_batch_size, replace=False)

        # Indices are sorted before the read (presumably so TileDB sees ordered ranges).
        random_indices.sort()
        batch_slice_key = (list(random_indices), slice(None))
        return self.cell_array_instance[batch_slice_key]

    def __iter__(self) -> Iterator[Union[np.ndarray, sp.spmatrix]]:
        """Yield `self.num_yields_per_epoch_per_worker` randomly sampled batches.

        This method is called by the DataLoader for each worker. For an empty
        array, a single untransformed empty batch is yielded (if any yields are
        requested) and iteration stops.
        """
        self._init_worker_state()

        for _ in range(self.num_yields_per_epoch_per_worker):
            if self.num_rows == 0:
                yield self._empty_batch()
                break

            batch_data = self._fetch_one_random_batch()

            if self.transform:
                batch_data = self.transform(batch_data)
            yield batch_data
187
+
188
+
189
def dense_batch_collate_fn(numpy_batch: np.ndarray) -> torch.Tensor:
    """Turn one dense batch from `CellArrayIterableDataset` into a tensor.

    The dataset already yields whole batches, so this receives the NumPy batch
    itself rather than a list of samples.

    Args:
        numpy_batch:
            The NumPy array yielded by the dataset, or None.

    Returns:
        ``torch.from_numpy(numpy_batch)``, which shares memory with the input,
        or an empty 1-D tensor when `numpy_batch` is None. A diagnostic line is
        printed when the batch is None or has zero rows.
    """
    is_missing = numpy_batch is None
    has_no_rows = (not is_missing) and hasattr(numpy_batch, "shape") and numpy_batch.shape[0] == 0

    if is_missing or has_no_rows:
        print("CollateFn (Dense): Received batch_item that is None or has 0 rows.")

    if is_missing:
        return torch.empty(0)
    return torch.from_numpy(numpy_batch)
201
+
202
+
203
def _sparse_batch_torch_dtype(scipy_sparse_batch) -> torch.dtype:
    """Map a sparse batch's value dtype to the matching torch dtype.

    Falls back to ``torch.float32`` when the batch is None, has no ``dtype``,
    or its dtype has no torch equivalent.
    """
    np_dtype = getattr(scipy_sparse_batch, "dtype", None)
    if np_dtype is None:
        return torch.float32
    try:
        return torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype
    except TypeError:
        return torch.float32


def sparse_batch_collate_fn(scipy_sparse_batch: sp.spmatrix) -> torch.Tensor:
    """Collate function for a sparse batch from CellArrayIterableDataset.

    Receives the scipy_sparse_batch directly from the dataset's iterator
    (the DataLoader uses ``batch_size=None``) and converts it to a torch
    sparse COO tensor.

    Args:
        scipy_sparse_batch:
            A SciPy sparse matrix in any format that supports ``tocoo()``, or None.

    Returns:
        A ``torch.sparse_coo_tensor`` with the input's shape and values. A None
        or zero-row input yields an empty tensor of shape ``(0, num_columns)``
        (``(0, 0)`` for None) whose dtype follows the input's value dtype, and
        a diagnostic line is printed.
    """
    has_no_rows = scipy_sparse_batch is None or (
        hasattr(scipy_sparse_batch, "shape") and scipy_sparse_batch.shape[0] == 0
    )
    if has_no_rows:
        print("CollateFn (Sparse): Received batch_item that is None or has 0 rows.")
        num_cols = 0 if scipy_sparse_batch is None else scipy_sparse_batch.shape[1]
        return torch.sparse_coo_tensor(
            torch.empty((2, 0), dtype=torch.long),
            torch.empty(0, dtype=_sparse_batch_torch_dtype(scipy_sparse_batch)),
            (0, num_cols),
        )

    coo = scipy_sparse_batch if isinstance(scipy_sparse_batch, sp.coo_matrix) else scipy_sparse_batch.tocoo()
    sparse_shape = torch.Size(coo.shape)

    if coo.nnz == 0:
        return torch.sparse_coo_tensor(
            torch.empty((2, 0), dtype=torch.long),
            torch.empty(0, dtype=_sparse_batch_torch_dtype(coo)),
            sparse_shape,
        )

    # Stack row/col into a (2, nnz) index matrix and widen it to int64 for torch.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
    values = torch.from_numpy(coo.data)
    return torch.sparse_coo_tensor(indices, values, sparse_shape)
249
+
250
+
251
def construct_iterable_dataloader(
    array_uri: str,
    is_sparse: bool,
    attribute_name: str = "data",
    num_rows: Optional[int] = None,
    num_columns: Optional[int] = None,
    batch_size: int = 1000,
    num_workers_dl: int = 2,
    num_yields_per_worker: int = 5,
    *,
    cellarr_ctx_config: Optional[Dict] = None,
    transform: Optional[Callable] = None,
) -> DataLoader:
    """Construct an instance of `CellArrayIterableDataset` with PyTorch DataLoader.

    Args:
        array_uri:
            URI of the TileDB array.

        is_sparse:
            True if the array is sparse, False for dense.

        attribute_name:
            Name of the attribute to read from.

        num_rows:
            The total number of rows in the TileDB array. Required.

        num_columns:
            The total number of columns in the TileDB array. Required.

        batch_size:
            Number of random samples per batch generated by the dataset.

        num_workers_dl:
            Number of worker processes for the DataLoader.

        num_yields_per_worker:
            Number of batches each worker should yield per epoch.

        cellarr_ctx_config:
            TileDB context configuration passed to the dataset. If None, a
            2000MB tile cache with 4 reader threads is used.

        transform:
            Optional callable applied to each fetched batch before collation.

    Returns:
        A DataLoader with ``batch_size=None`` that yields one tensor per dataset
        batch: a dense tensor, or a sparse COO tensor when `is_sparse` is True.

    Raises:
        ValueError: If `num_rows` or `num_columns` is None, or if the dataset
            rejects the given sizes.
    """
    if num_rows is None or num_columns is None:
        raise ValueError("num_rows and num_columns must be provided for CellArrayIterableDataset.")

    if cellarr_ctx_config is None:
        cellarr_ctx_config = {
            "sm.tile_cache_size": 2000 * 1024**2,  # 2000MB tile cache
            "sm.num_reader_threads": 4,
        }

    dataset = CellArrayIterableDataset(
        array_uri=array_uri,
        attribute_name=attribute_name,
        num_rows=num_rows,
        num_columns=num_columns,
        is_sparse=is_sparse,
        batch_size=batch_size,
        num_yields_per_epoch_per_worker=num_yields_per_worker,
        cellarr_ctx_config=cellarr_ctx_config,
        transform=transform,
    )

    collate_to_use = sparse_batch_collate_fn if is_sparse else dense_batch_collate_fn
    uses_workers = num_workers_dl > 0

    return DataLoader(
        dataset,
        # The dataset already yields whole batches, so automatic batching is disabled.
        batch_size=None,
        num_workers=num_workers_dl,
        worker_init_fn=seed_worker,
        collate_fn=collate_to_use,
        # Pinned memory is only requested for dense batches on CUDA machines with workers.
        pin_memory=not is_sparse and torch.cuda.is_available() and uses_workers,
        persistent_workers=uses_workers,
        prefetch_factor=2 if uses_workers else None,
    )
@@ -0,0 +1,230 @@
1
+ from typing import Optional
2
+ from warnings import warn
3
+
4
+ import scipy.sparse as sp
5
+ import tiledb
6
+ import torch
7
+ from torch.utils.data import DataLoader, Dataset
8
+
9
+ from ..core.sparse import SparseCellArray
10
+
11
+ __author__ = "Jayaram Kancherla"
12
+ __copyright__ = "Jayaram Kancherla"
13
+ __license__ = "MIT"
14
+
15
+
16
class SparseArrayDataset(Dataset):
    """Map-style PyTorch Dataset over the rows of a sparse TileDB array.

    Each item is one row read through `SparseCellArray`, optionally transformed,
    and returned as a SciPy COO matrix.
    """

    def __init__(
        self,
        array_uri: str,
        attribute_name: str = "data",
        num_rows: Optional[int] = None,
        num_columns: Optional[int] = None,
        sparse_format=sp.csr_matrix,
        cellarr_ctx_config: Optional[dict] = None,
        transform=None,
    ):
        """PyTorch Dataset for sparse TileDB arrays accessed via SparseCellArray.

        Args:
            array_uri:
                URI of the TileDB sparse array.

            attribute_name:
                Name of the attribute to read from.

            num_rows:
                Total number of rows in the dataset.
                If None, will infer from `array.shape[0]`.

            num_columns:
                The number of columns in the dataset.
                If None, will attempt to infer from `array.shape[1]`
                (1 for a 1-D array).

            sparse_format:
                SciPy sparse class passed to the reader as its output format;
                defaults to csr_matrix. Items are converted to COO before being
                returned regardless.

            cellarr_ctx_config:
                Optional TileDB context configuration dict for CellArray.

            transform:
                Optional transform applied to each row's sparse matrix before
                the COO conversion.

        Raises:
            ValueError:
                If a dimension must be inferred and probing the array fails
                (including arrays that are neither 1-D nor 2-D). Also raised if
                `num_columns` is None, or is non-positive while the dataset is
                non-empty.
        """
        self.array_uri = array_uri
        self.attribute_name = attribute_name
        self.sparse_format = sparse_format
        self.cellarr_ctx_config = cellarr_ctx_config
        self.transform = transform
        # Opened lazily per process by `_init_worker_state`.
        self.cell_array_instance = None

        if num_rows is not None and num_columns is not None:
            self._len = num_rows
            self.num_columns = num_columns
        else:
            self._len, self.num_columns = self._infer_shape(num_rows, num_columns)

        if self.num_columns is None or (self.num_columns <= 0 and self._len > 0):
            raise ValueError(
                f"num_columns ({self.num_columns}) is invalid or could not be determined for sparse array '{array_uri}'."
            )

        if self._len == 0:
            warn(f"SparseDataset for '{array_uri}' has length 0.", RuntimeWarning)

    def _infer_shape(self, num_rows: Optional[int], num_columns: Optional[int]):
        """Open the array once to fill in whichever of `num_rows` / `num_columns` is None.

        Returns:
            A ``(num_rows, num_columns)`` tuple; values supplied by the caller are kept.

        Raises:
            ValueError: If probing fails for any reason. At least one dimension is
                missing whenever this runs, so there is no caller value to fall back on.
        """
        array_uri = self.array_uri
        print(f"Dataset '{array_uri}': num_rows or num_columns not provided. Probing sparse array...")
        init_ctx_config = tiledb.Config(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
        try:
            temp_arr = SparseCellArray(
                uri=array_uri,
                attr=self.attribute_name,
                mode="r",
                config_or_context=init_ctx_config,
                return_sparse=True,
                # Same keyword `_init_worker_state` uses for the output format.
                sparse_coerce=self.sparse_format,
            )

            if temp_arr.ndim == 1:
                inferred_rows = num_rows if num_rows is not None else temp_arr.shape[0]
                inferred_columns = 1
            elif temp_arr.ndim == 2:
                inferred_rows = num_rows if num_rows is not None else temp_arr.shape[0]
                inferred_columns = num_columns if num_columns is not None else temp_arr.shape[1]
            else:
                raise ValueError(f"Array ndim {temp_arr.ndim} not supported.")
        except Exception as e:
            raise ValueError(
                f"num_rows and num_columns must be provided if inferring sparse array shape fails for '{array_uri}'. Original error: {e}"
            ) from e

        print(f"Dataset '{array_uri}': Inferred sparse shape. Rows: {inferred_rows}, Columns: {inferred_columns}")
        return inferred_rows, inferred_columns

    def _init_worker_state(self):
        """Create this process's `SparseCellArray` reader on first use.

        Called from `__getitem__`, so each DataLoader worker opens its own reader.
        """
        if self.cell_array_instance is None:
            ctx = tiledb.Ctx(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
            self.cell_array_instance = SparseCellArray(
                uri=self.array_uri,
                attr=self.attribute_name,
                mode="r",
                config_or_context=ctx,
                return_sparse=True,
                sparse_coerce=self.sparse_format,
            )

    def __len__(self):
        """Number of rows exposed by the dataset."""
        return self._len

    def __getitem__(self, idx):
        """Return row `idx` as a single-row SciPy COO matrix.

        Raises:
            IndexError: If `idx` is outside ``[0, len(self))``; negative indices are rejected.
        """
        if not 0 <= idx < self._len:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {self._len}.")

        self._init_worker_state()

        # One-row slice with all columns.
        item_slice = (slice(idx, idx + 1), slice(None))
        scipy_sparse_sample = self.cell_array_instance[item_slice]

        if self.transform:
            scipy_sparse_sample = self.transform(scipy_sparse_sample)

        # `sparse_coo_collate_fn` reads .data/.col, so normalise to COO.
        if not isinstance(scipy_sparse_sample, sp.coo_matrix):
            scipy_sparse_sample = scipy_sparse_sample.tocoo()

        return scipy_sparse_sample
138
+
139
+
140
def sparse_coo_collate_fn(batch):
    """Custom collate_fn for a batch of SciPy COO sparse matrices.

    Converts them into a single batched PyTorch sparse COO tensor. Item ``i``
    of `batch` becomes row ``i`` of the result. Each sample's own row indices
    are replaced by its batch position, so every sample is expected to be a
    single-row matrix, as `SparseArrayDataset.__getitem__` returns. The column
    count is taken from the first sample.

    Args:
        batch:
            A list of SciPy ``coo_matrix`` objects, one per sample.

    Returns:
        A ``torch.sparse_coo_tensor`` of shape ``(len(batch), num_columns)``.
        If no sample stores any values, an empty sparse tensor is returned with
        the first sample's value dtype (float32 for an empty list).
    """
    batch_size = len(batch)
    num_columns = batch[0].shape[1] if batch else 0

    all_data = []
    all_row_indices = []
    all_col_indices = []

    for i, scipy_coo in enumerate(batch):
        if scipy_coo.nnz > 0:
            all_data.append(torch.from_numpy(scipy_coo.data))
            all_row_indices.append(torch.full((len(scipy_coo.data),), i, dtype=torch.long))
            # Cast explicitly: SciPy index arrays are commonly int32, torch wants int64.
            all_col_indices.append(torch.from_numpy(scipy_coo.col).long())

    if not all_data:
        # Keep the value dtype consistent with what a non-empty batch would produce.
        value_dtype = torch.from_numpy(batch[0].data[:0]).dtype if batch else torch.float32
        return torch.sparse_coo_tensor(
            torch.empty((2, 0), dtype=torch.long),
            torch.empty(0, dtype=value_dtype),
            (batch_size, num_columns),
        )

    indices = torch.stack([torch.cat(all_row_indices), torch.cat(all_col_indices)], dim=0)
    return torch.sparse_coo_tensor(indices, torch.cat(all_data), (batch_size, num_columns))
171
+
172
+
173
def construct_sparse_array_dataloader(
    array_uri: str,
    attribute_name: str = "data",
    num_rows: Optional[int] = None,
    num_columns: Optional[int] = None,
    batch_size: int = 1000,
    num_workers_dl: int = 2,
    *,
    cellarr_ctx_config: Optional[dict] = None,
    transform=None,
) -> Optional[DataLoader]:
    """Construct an instance of `SparseArrayDataset` with PyTorch DataLoader.

    Args:
        array_uri:
            URI of the TileDB array.

        attribute_name:
            Name of the attribute to read from.

        num_rows:
            The total number of rows in the TileDB array. Inferred from the array if None.

        num_columns:
            The total number of columns in the TileDB array. Inferred from the array if None.

        batch_size:
            Number of rows per DataLoader batch. Rows are visited in shuffled
            order, each once per epoch.

        num_workers_dl:
            Number of worker processes for the DataLoader.

        cellarr_ctx_config:
            TileDB context configuration passed to the dataset. If None, a
            1000MB tile cache with 4 reader threads is used.

        transform:
            Optional transform applied to each row before COO conversion.

    Returns:
        A DataLoader yielding ``(batch_size, num_columns)`` sparse COO tensors,
        or None if the dataset is empty.
    """
    if cellarr_ctx_config is None:
        cellarr_ctx_config = {
            "sm.tile_cache_size": 1000 * 1024**2,
            "sm.num_reader_threads": 4,
        }

    dataset = SparseArrayDataset(
        array_uri=array_uri,
        attribute_name=attribute_name,
        num_rows=num_rows,
        num_columns=num_columns,
        sparse_format=sp.coo_matrix,
        cellarr_ctx_config=cellarr_ctx_config,
        transform=transform,
    )

    # shuffle=True uses a RandomSampler, which rejects zero-length datasets.
    if len(dataset) == 0:
        print("Dataset is empty, cannot create DataLoader.")
        return None

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers_dl,
        collate_fn=sparse_coo_collate_fn,
        pin_memory=False,
        persistent_workers=num_workers_dl > 0,
    )
@@ -0,0 +1,26 @@
1
+ import random
2
+
3
+ import numpy as np
4
+
5
+ __author__ = "Jayaram Kancherla"
6
+ __copyright__ = "Jayaram Kancherla"
7
+ __license__ = "MIT"
8
+
9
+
10
def seed_worker(worker_id: int):
    """Seed NumPy's and Python's global RNGs inside a PyTorch DataLoader worker.

    Intended to be passed as a DataLoader ``worker_init_fn``. The seed is
    derived from ``torch.initial_seed()``, which differs per worker, so workers
    that sample randomly draw different sequences of random numbers.

    Args:
        worker_id:
            The ID of the worker process. Unused; accepted because
            ``worker_init_fn`` is called with it.
    """
    import torch

    # np.random.seed only accepts values below 2**32, while torch seeds are 64-bit.
    derived_seed = torch.initial_seed() % 2**32
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
@@ -0,0 +1,3 @@
1
+ from .config import CellArrConfig, ConsolidationConfig
2
+ from ..core.helpers import create_cellarray
3
+ # from .mock import generate_tiledb_dense_array, generate_tiledb_sparse_array