scdataloader 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +63 -42
- scdataloader/collator.py +87 -43
- scdataloader/config.py +106 -0
- scdataloader/data.py +78 -98
- scdataloader/datamodule.py +375 -0
- scdataloader/mapped.py +22 -7
- scdataloader/preprocess.py +444 -109
- scdataloader/utils.py +106 -63
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/METADATA +46 -2
- scdataloader-0.0.4.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/entry_points.txt +0 -0
scdataloader/mapped.py
CHANGED
|
@@ -93,6 +93,22 @@ class MappedDataset:
|
|
|
93
93
|
for storage in self.storages:
|
|
94
94
|
with _Connect(storage) as store:
|
|
95
95
|
X = store["X"]
|
|
96
|
+
index = (
|
|
97
|
+
store["var"]["ensembl_gene_id"]
|
|
98
|
+
if "ensembl_gene_id" in store["var"]
|
|
99
|
+
else store["var"]["_index"]
|
|
100
|
+
)
|
|
101
|
+
if join_vars == "None":
|
|
102
|
+
if not all(
|
|
103
|
+
[
|
|
104
|
+
i <= j
|
|
105
|
+
for i, j in zip(
|
|
106
|
+
index[:99],
|
|
107
|
+
index[1:100],
|
|
108
|
+
)
|
|
109
|
+
]
|
|
110
|
+
):
|
|
111
|
+
raise ValueError("The variables are not sorted.")
|
|
96
112
|
if isinstance(X, ArrayTypes): # type: ignore
|
|
97
113
|
self.n_obs_list.append(X.shape[0])
|
|
98
114
|
else:
|
|
@@ -179,18 +195,15 @@ class MappedDataset:
|
|
|
179
195
|
with _Connect(self.storages[storage_idx]) as store:
|
|
180
196
|
out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
|
|
181
197
|
if self.label_keys is not None:
|
|
182
|
-
for i, label in enumerate(self.label_keys):
|
|
198
|
+
for _, label in enumerate(self.label_keys):
|
|
183
199
|
label_idx = self.get_label_idx(store, obs_idx, label)
|
|
184
200
|
if label in self.encoders:
|
|
185
201
|
out.update({label: self.encoders[label][label_idx]})
|
|
186
202
|
else:
|
|
187
203
|
out.update({label: label_idx})
|
|
204
|
+
out.update({"dataset": storage_idx})
|
|
188
205
|
return out
|
|
189
206
|
|
|
190
|
-
def uns(self, idx, key):
|
|
191
|
-
storage = self.storages[self.storage_idx[idx]]
|
|
192
|
-
return storage["uns"][key]
|
|
193
|
-
|
|
194
207
|
def get_data_idx(
|
|
195
208
|
self,
|
|
196
209
|
storage: StorageType,
|
|
@@ -247,9 +260,11 @@ class MappedDataset:
|
|
|
247
260
|
else:
|
|
248
261
|
labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
|
|
249
262
|
counter = Counter(labels) # type: ignore
|
|
250
|
-
|
|
263
|
+
rn = {n: i for i, n in enumerate(counter.keys())}
|
|
264
|
+
labels = np.array([rn[label] for label in labels])
|
|
265
|
+
counter = np.array(list(counter.values()))
|
|
251
266
|
weights = scaler / (counter + scaler)
|
|
252
|
-
return weights
|
|
267
|
+
return weights, labels
|
|
253
268
|
|
|
254
269
|
def get_merged_labels(self, label_key: str):
|
|
255
270
|
"""Get merged labels."""
|